From 37543c40f441155aac66f97276d79ff58a91a590 Mon Sep 17 00:00:00 2001 From: Azure SDK Bot <53356347+azure-sdk@users.noreply.github.com> Date: Mon, 16 Sep 2024 16:01:31 -0700 Subject: [PATCH 01/17] Sync eng/common directory with azure-sdk-tools for PR 8974 (#37417) * updating package properties with direct/indirect (if named differently) as well as pulling BuildDocs from ci.yml artifact list if it exists * eliminate the addition of buildDocs property. it requires powershell-yaml to be present on our base function. not good * remove call to InitializeBuildDocs --------- Co-authored-by: Scott Beddall --- eng/common/scripts/Package-Properties.ps1 | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/eng/common/scripts/Package-Properties.ps1 b/eng/common/scripts/Package-Properties.ps1 index e3cea7d392e..fabe342388e 100644 --- a/eng/common/scripts/Package-Properties.ps1 +++ b/eng/common/scripts/Package-Properties.ps1 @@ -15,6 +15,9 @@ class PackageProps [boolean]$IsNewSdk [string]$ArtifactName [string]$ReleaseStatus + # was this package purely included because other packages included it as an AdditionalValidationPackage? + [boolean]$IncludedForValidation + # does this package include other packages that we should trigger validation for? 
[string[]]$AdditionalValidationPackages PackageProps([string]$name, [string]$version, [string]$directoryPath, [string]$serviceDirectory) @@ -38,6 +41,7 @@ class PackageProps $this.Version = $version $this.DirectoryPath = $directoryPath $this.ServiceDirectory = $serviceDirectory + $this.IncludedForValidation = $false if (Test-Path (Join-Path $directoryPath "README.md")) { @@ -143,6 +147,7 @@ function Get-PrPkgProperties([string]$InputDiffJson) { $key = $addition.Replace($RepoRoot, "").TrimStart('\/') if ($lookup[$key]) { + $lookup[$key].IncludedForValidation = $true $packagesWithChanges += $lookup[$key] } } From 57316966c1d32c705da539e67259b79387a21bca Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Mon, 16 Sep 2024 16:55:43 -0700 Subject: [PATCH 02/17] Fixed the issue that `encryptionKey` was lost during serialization (#37410) * To fix #37251 * add test * update * update * Update changelog --- .../azure-search-documents/CHANGELOG.md | 79 ++++++++++++++++++- .../documents/indexes/models/_models.py | 12 ++- .../tests/test_models.py | 33 ++++++++ 3 files changed, 119 insertions(+), 5 deletions(-) create mode 100644 sdk/search/azure-search-documents/tests/test_models.py diff --git a/sdk/search/azure-search-documents/CHANGELOG.md b/sdk/search/azure-search-documents/CHANGELOG.md index 7312c77318e..cfa1faa0e7b 100644 --- a/sdk/search/azure-search-documents/CHANGELOG.md +++ b/sdk/search/azure-search-documents/CHANGELOG.md @@ -1,16 +1,15 @@ # Release History -## 11.6.0b5 (Unreleased) +## 11.6.0b5 (2024-09-17) ### Features Added - `SearchIndexClient`.`get_search_client` inherits the API version. -### Breaking Changes - ### Bugs Fixed - Fixed the issue that we missed ODATA header when using Entra ID auth. +- Fixed the issue that `encryptionKey` was lost during serialization. 
#37251 ### Other Changes @@ -32,6 +31,80 @@ - `azure.search.documents.indexes.models.VectorSearchProfile.vectorizer` -> `azure.search.documents.indexes.models.VectorSearchProfile.vectorizer_name` - `azure.search.documents.indexes.models.VectorSearchVectorizer.name` -> `azure.search.documents.indexes.models.VectorSearchVectorizer.vectorizer_name` +## 11.5.1 (2024-07-30) + +### Other Changes + +- Improved type checks. + +## 11.5.0 (2024-07-16) + +### Breaking Changes + +> These changes do not impact the API of stable versions such as 11.4.0. +> Only code written against a beta version such as 11.6.0b4 may be affected. +- Below models are renamed + - `azure.search.documents.indexes.models.SearchIndexerIndexProjections` -> `azure.search.documents.indexes.models.SearchIndexerIndexProjection` + - `azure.search.documents.indexes.models.LineEnding` -> `azure.search.documents.indexes.models.OrcLineEnding` + - `azure.search.documents.indexes.models.ScalarQuantizationCompressionConfiguration` -> `azure.search.documents.indexes.models.ScalarQuantizationCompression` + - `azure.search.documents.indexes.models.VectorSearchCompressionConfiguration` -> `azure.search.documents.indexes.models.VectorSearchCompression` + - `azure.search.documents.indexes.models.VectorSearchCompressionTargetDataType` -> `azure.search.documents.indexes.models.VectorSearchCompressionTarget` + +- Below models do not exist in this release + - `azure.search.documents.models.QueryLanguage` + - `azure.search.documents.models.QuerySpellerType` + - `azure.search.documents.models.QueryDebugMode` + - `azure.search.documents.models.HybridCountAndFacetMode` + - `azure.search.documents.models.HybridSearch` + - `azure.search.documents.models.SearchScoreThreshold` + - `azure.search.documents.models.VectorSimilarityThreshold` + - `azure.search.documents.models.VectorThreshold` + - `azure.search.documents.models.VectorThresholdKind` + - `azure.search.documents.models.VectorizableImageBinaryQuery` + - 
`azure.search.documents.models.VectorizableImageUrlQuery` + - `azure.search.documents.indexes.models.SearchAlias` + - `azure.search.documents.indexes.models.AIServicesVisionParameters` + - `azure.search.documents.indexes.models.AIServicesVisionVectorizer` + - `azure.search.documents.indexes.models.AIStudioModelCatalogName` + - `azure.search.documents.indexes.models.AzureMachineLearningParameters` + - `azure.search.documents.indexes.models.AzureMachineLearningSkill` + - `azure.search.documents.indexes.models.AzureMachineLearningVectorizer` + - `azure.search.documents.indexes.models.CustomVectorizer` + - `azure.search.documents.indexes.models.CustomNormalizer` + - `azure.search.documents.indexes.models.DocumentKeysOrIds` + - `azure.search.documents.indexes.models.IndexingMode` + - `azure.search.documents.indexes.models.LexicalNormalizer` + - `azure.search.documents.indexes.models.LexicalNormalizerName` + - `azure.search.documents.indexes.models.NativeBlobSoftDeleteDeletionDetectionPolicy` + - `azure.search.documents.indexes.models.SearchIndexerCache` + - `azure.search.documents.indexes.models.SkillNames` + - `azure.search.documents.indexes.models.VisionVectorizeSkill` + +- SearchAlias operations do not exist in this release +- `SearchIndexerClient.reset_documents` does not exist in this release +- `SearchIndexerClient.reset_skills` does not exist in this release + +- Below properties do not exist + - `azure.search.documents.indexes.models.SearchIndexerDataSourceConnection.identity` + - `azure.search.documents.indexes.models.SearchIndex.normalizers` + - `azure.search.documents.indexes.models.SearchField.normalizer_name` + +- Below parameters do not exist + - `SearchClient.search.debug` + - `SearchClient.search.hybrid_search` + - `SearchClient.search.query_language` + - `SearchClient.search.query_speller` + - `SearchClient.search.semantic_fields` + - `SearchIndexerClient.create_or_update_indexer.skip_indexer_reset_requirement_for_cache` + - 
`SearchIndexerClient.create_or_update_data_source_connection.skip_indexer_reset_requirement_for_cache` + - `SearchIndexerClient.create_or_update_skillset.skip_indexer_reset_requirement_for_cache` + - `SearchIndexerClient.create_or_update_indexer.disable_cache_reprocessing_change_detection` + - `SearchIndexerClient.create_or_update_skillset.disable_cache_reprocessing_change_detection` + +### Other Changes + +- Updated default API version to `2024-07-01`. + ## 11.6.0b4 (2024-05-07) ### Features Added diff --git a/sdk/search/azure-search-documents/azure/search/documents/indexes/models/_models.py b/sdk/search/azure-search-documents/azure/search/documents/indexes/models/_models.py index 1a2cca3bcf9..ea89fcb569d 100644 --- a/sdk/search/azure-search-documents/azure/search/documents/indexes/models/_models.py +++ b/sdk/search/azure-search-documents/azure/search/documents/indexes/models/_models.py @@ -1121,7 +1121,9 @@ class SearchIndexerDataSourceConnection(_serialization.Model): data_change_detection_policy=self.data_change_detection_policy, data_deletion_detection_policy=self.data_deletion_detection_policy, e_tag=self.e_tag, - encryption_key=self.encryption_key, + encryption_key=( + self.encryption_key._to_generated() if self.encryption_key else None # pylint: disable=protected-access + ), identity=self.identity, ) @@ -1141,7 +1143,13 @@ class SearchIndexerDataSourceConnection(_serialization.Model): data_change_detection_policy=search_indexer_data_source.data_change_detection_policy, data_deletion_detection_policy=search_indexer_data_source.data_deletion_detection_policy, e_tag=search_indexer_data_source.e_tag, - encryption_key=search_indexer_data_source.encryption_key, + encryption_key=( + SearchResourceEncryptionKey._from_generated( # pylint: disable=protected-access + search_indexer_data_source.encryption_key + ) + if search_indexer_data_source.encryption_key + else None + ), identity=search_indexer_data_source.identity, ) diff --git 
a/sdk/search/azure-search-documents/tests/test_models.py b/sdk/search/azure-search-documents/tests/test_models.py new file mode 100644 index 00000000000..eddef0d60d5 --- /dev/null +++ b/sdk/search/azure-search-documents/tests/test_models.py @@ -0,0 +1,33 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ + +from azure.search.documents.indexes.models import ( + SearchIndexerDataContainer, + SearchIndexerDataSourceConnection, + SearchResourceEncryptionKey, +) + + +def test_encryption_key_serialization(): + from azure.search.documents.indexes._generated.models import ( + SearchResourceEncryptionKey as SearchResourceEncryptionKeyGen, + ) + + container = SearchIndexerDataContainer(name="container_name") + encryption_key = SearchResourceEncryptionKey( + key_name="key", + key_version="key_version", + vault_uri="vault_uri", + application_id="application_id", + ) + data_source_connection = SearchIndexerDataSourceConnection( + name="datasource-name", + type="azureblob", + connection_string="connection_string", + container=container, + encryption_key=encryption_key, + ) + packed_data_source = data_source_connection._to_generated() + assert isinstance(packed_data_source.encryption_key, SearchResourceEncryptionKeyGen) From e4cabf3ff171e98e278b7a4f4f0fda912bc84fae Mon Sep 17 00:00:00 2001 From: Neehar Duvvuri <40341266+needuv@users.noreply.github.com> Date: Mon, 16 Sep 2024 20:09:05 -0400 Subject: [PATCH 03/17] Remove jsonpath-ng Dependency (#37418) --- sdk/evaluation/azure-ai-evaluation/setup.py | 1 - shared_requirements.txt | 1 - 2 files changed, 2 deletions(-) diff --git a/sdk/evaluation/azure-ai-evaluation/setup.py b/sdk/evaluation/azure-ai-evaluation/setup.py index af73fcd6e2d..3d0cdb27873 100644 --- a/sdk/evaluation/azure-ai-evaluation/setup.py +++ b/sdk/evaluation/azure-ai-evaluation/setup.py @@ -68,7 +68,6 @@ setup( "promptflow-devkit>=1.15.0", 
"promptflow-core>=1.15.0", "websocket-client>=1.2.0", - "jsonpath_ng>=1.5.0", "numpy>=1.23.2; python_version<'3.12'", "numpy>=1.26.4; python_version>='3.12'", "pyjwt>=2.8.0", diff --git a/shared_requirements.txt b/shared_requirements.txt index 295708bf1d7..e9fc401f0cd 100644 --- a/shared_requirements.txt +++ b/shared_requirements.txt @@ -71,6 +71,5 @@ dnspython promptflow-core promptflow-devkit numpy -jsonpath-ng nltk rouge-score \ No newline at end of file From c34465e4d94dffaef62a2e1545f8b4638d444367 Mon Sep 17 00:00:00 2001 From: Peter Wu <162184229+weirongw23-msft@users.noreply.github.com> Date: Mon, 16 Sep 2024 22:13:48 -0400 Subject: [PATCH 04/17] [Storage] [STG 95] STG 95 GA Release Changelogs (#37414) --- sdk/storage/azure-storage-blob/CHANGELOG.md | 5 ++++- sdk/storage/azure-storage-file-datalake/CHANGELOG.md | 4 ++-- sdk/storage/azure-storage-file-share/CHANGELOG.md | 4 ++-- sdk/storage/azure-storage-queue/CHANGELOG.md | 4 ++-- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/sdk/storage/azure-storage-blob/CHANGELOG.md b/sdk/storage/azure-storage-blob/CHANGELOG.md index 06e92777e76..da94c657e8f 100644 --- a/sdk/storage/azure-storage-blob/CHANGELOG.md +++ b/sdk/storage/azure-storage-blob/CHANGELOG.md @@ -1,9 +1,12 @@ # Release History -## 12.23.0 (Unreleased) +## 12.23.0 (2024-09-17) ### Features Added +- Stable release of features from 12.23.0b1 +### Bugs Fixed +- Fixed an issue with batch APIs when using Azurite. 
## 12.23.0b1 (2024-08-07) diff --git a/sdk/storage/azure-storage-file-datalake/CHANGELOG.md b/sdk/storage/azure-storage-file-datalake/CHANGELOG.md index 7ae4690465c..e059c2697d0 100644 --- a/sdk/storage/azure-storage-file-datalake/CHANGELOG.md +++ b/sdk/storage/azure-storage-file-datalake/CHANGELOG.md @@ -1,9 +1,9 @@ # Release History -## 12.17.0 (Unreleased) +## 12.17.0 (2024-09-17) ### Features Added - +- Stable release of features from 12.17.0b1 ## 12.17.0b1 (2024-08-07) diff --git a/sdk/storage/azure-storage-file-share/CHANGELOG.md b/sdk/storage/azure-storage-file-share/CHANGELOG.md index a267a6f7ac7..a6be53d879f 100644 --- a/sdk/storage/azure-storage-file-share/CHANGELOG.md +++ b/sdk/storage/azure-storage-file-share/CHANGELOG.md @@ -1,9 +1,9 @@ # Release History -## 12.18.0 (Unreleased) +## 12.18.0 (2024-09-17) ### Features Added - +- Stable release of features from 12.18.0b1 ## 12.18.0b1 (2024-08-07) diff --git a/sdk/storage/azure-storage-queue/CHANGELOG.md b/sdk/storage/azure-storage-queue/CHANGELOG.md index 9656a24deb1..76ba40926a0 100644 --- a/sdk/storage/azure-storage-queue/CHANGELOG.md +++ b/sdk/storage/azure-storage-queue/CHANGELOG.md @@ -1,9 +1,9 @@ # Release History -## 12.12.0 (Unreleased) +## 12.12.0 (2024-09-17) ### Features Added - +- Stable release of features from 12.12.0b1 ## 12.12.0b1 (2024-08-07) From 9b37e9bc9370922e9d59f60acb9d780f64e3ba6a Mon Sep 17 00:00:00 2001 From: Waqas Javed <7674577+w-javed@users.noreply.github.com> Date: Tue, 17 Sep 2024 13:11:15 -0700 Subject: [PATCH 05/17] Azure-Ai-Gen-Pkg-version-change-to-b9 (#37416) * version-change-to-b9 * version-change-to-b9 --- sdk/ai/azure-ai-generative/CHANGELOG.md | 5 +++++ sdk/ai/azure-ai-generative/azure/ai/generative/_version.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/sdk/ai/azure-ai-generative/CHANGELOG.md b/sdk/ai/azure-ai-generative/CHANGELOG.md index 14ebe626f9d..5a90cfbfe6c 100644 --- a/sdk/ai/azure-ai-generative/CHANGELOG.md +++ 
b/sdk/ai/azure-ai-generative/CHANGELOG.md @@ -1,5 +1,10 @@ # Release History +## 1.0.0b9 (2024-09-16) + +### Bugs Fixed +security bug - code injection + ## 1.0.0b8 (2024-03-27) ### Other Changes diff --git a/sdk/ai/azure-ai-generative/azure/ai/generative/_version.py b/sdk/ai/azure-ai-generative/azure/ai/generative/_version.py index 1b0058b20ee..03f444e3024 100644 --- a/sdk/ai/azure-ai-generative/azure/ai/generative/_version.py +++ b/sdk/ai/azure-ai-generative/azure/ai/generative/_version.py @@ -2,4 +2,4 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -VERSION = "1.0.0b8" +VERSION = "1.0.0b9" From 807ef332799c488e17db9685d5d2276eae035d64 Mon Sep 17 00:00:00 2001 From: Azure SDK Bot <53356347+azure-sdk@users.noreply.github.com> Date: Tue, 17 Sep 2024 14:51:17 -0700 Subject: [PATCH 06/17] Update CodeownersLinter for net6 to net8 update (#37431) Co-authored-by: James Suplizio --- eng/common/pipelines/codeowners-linter.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eng/common/pipelines/codeowners-linter.yml b/eng/common/pipelines/codeowners-linter.yml index b3e6ec1cd95..f815c4944ff 100644 --- a/eng/common/pipelines/codeowners-linter.yml +++ b/eng/common/pipelines/codeowners-linter.yml @@ -31,7 +31,7 @@ stages: vmImage: ubuntu-22.04 variables: - CodeownersLinterVersion: '1.0.0-dev.20240614.4' + CodeownersLinterVersion: '1.0.0-dev.20240917.2' DotNetDevOpsFeed: "https://pkgs.dev.azure.com/azure-sdk/public/_packaging/azure-sdk-for-net/nuget/v3/index.json" RepoLabelUri: "https://azuresdkartifacts.blob.core.windows.net/azure-sdk-write-teams/repository-labels-blob" TeamUserUri: "https://azuresdkartifacts.blob.core.windows.net/azure-sdk-write-teams/azure-sdk-write-teams-blob" From a60b09ecf04a2789f58c83efc7acb19a0ccffe89 Mon Sep 17 00:00:00 2001 From: Peter Wu <162184229+weirongw23-msft@users.noreply.github.com> Date: Tue, 17 Sep 2024 19:24:48 -0400 Subject: [PATCH 07/17] bump 
versions after stg 95 ga release (#37437) --- sdk/storage/azure-storage-blob/CHANGELOG.md | 4 ++++ sdk/storage/azure-storage-blob/azure/storage/blob/_version.py | 2 +- sdk/storage/azure-storage-blob/setup.py | 2 +- sdk/storage/azure-storage-file-datalake/CHANGELOG.md | 4 ++++ .../azure/storage/filedatalake/_version.py | 2 +- sdk/storage/azure-storage-file-datalake/setup.py | 4 ++-- sdk/storage/azure-storage-file-share/CHANGELOG.md | 4 ++++ .../azure/storage/fileshare/_version.py | 2 +- sdk/storage/azure-storage-file-share/setup.py | 2 +- sdk/storage/azure-storage-queue/CHANGELOG.md | 4 ++++ .../azure-storage-queue/azure/storage/queue/_version.py | 2 +- sdk/storage/azure-storage-queue/setup.py | 2 +- 12 files changed, 25 insertions(+), 9 deletions(-) diff --git a/sdk/storage/azure-storage-blob/CHANGELOG.md b/sdk/storage/azure-storage-blob/CHANGELOG.md index da94c657e8f..48c9cb22f45 100644 --- a/sdk/storage/azure-storage-blob/CHANGELOG.md +++ b/sdk/storage/azure-storage-blob/CHANGELOG.md @@ -1,5 +1,9 @@ # Release History +## 12.24.0b1 (Unreleased) + +### Features Added + ## 12.23.0 (2024-09-17) ### Features Added diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_version.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_version.py index 9bdabb44022..f67466f0741 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_version.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_version.py @@ -4,4 +4,4 @@ # license information. 
# -------------------------------------------------------------------------- -VERSION = "12.23.0" +VERSION = "12.24.0b1" diff --git a/sdk/storage/azure-storage-blob/setup.py b/sdk/storage/azure-storage-blob/setup.py index b011968e8e1..d21b26946e6 100644 --- a/sdk/storage/azure-storage-blob/setup.py +++ b/sdk/storage/azure-storage-blob/setup.py @@ -56,7 +56,7 @@ setup( url='https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-blob', keywords="azure, azure sdk", classifiers=[ - 'Development Status :: 5 - Production/Stable', + 'Development Status :: 4 - Beta', 'Programming Language :: Python', 'Programming Language :: Python :: 3 :: Only', 'Programming Language :: Python :: 3', diff --git a/sdk/storage/azure-storage-file-datalake/CHANGELOG.md b/sdk/storage/azure-storage-file-datalake/CHANGELOG.md index e059c2697d0..ec63391ba72 100644 --- a/sdk/storage/azure-storage-file-datalake/CHANGELOG.md +++ b/sdk/storage/azure-storage-file-datalake/CHANGELOG.md @@ -1,5 +1,9 @@ # Release History +## 12.18.0b1 (Unreleased) + +### Features Added + ## 12.17.0 (2024-09-17) ### Features Added diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_version.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_version.py index 64bfa0b4f1a..e2a8a2d4a09 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_version.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_version.py @@ -4,4 +4,4 @@ # license information. 
# -------------------------------------------------------------------------- -VERSION = "12.17.0" +VERSION = "12.18.0b1" diff --git a/sdk/storage/azure-storage-file-datalake/setup.py b/sdk/storage/azure-storage-file-datalake/setup.py index 19462196b45..473b7d05d6a 100644 --- a/sdk/storage/azure-storage-file-datalake/setup.py +++ b/sdk/storage/azure-storage-file-datalake/setup.py @@ -57,7 +57,7 @@ setup( url='https://github.com/Azure/azure-sdk-for-python', keywords="azure, azure sdk", classifiers=[ - 'Development Status :: 5 - Production/Stable', + 'Development Status :: 4 - Beta', 'Programming Language :: Python', 'Programming Language :: Python :: 3 :: Only', 'Programming Language :: Python :: 3', @@ -78,7 +78,7 @@ setup( python_requires=">=3.8", install_requires=[ "azure-core>=1.30.0", - "azure-storage-blob>=12.23.0", + "azure-storage-blob>=12.24.0b1", "typing-extensions>=4.6.0", "isodate>=0.6.1" ], diff --git a/sdk/storage/azure-storage-file-share/CHANGELOG.md b/sdk/storage/azure-storage-file-share/CHANGELOG.md index a6be53d879f..cb5900c1cb9 100644 --- a/sdk/storage/azure-storage-file-share/CHANGELOG.md +++ b/sdk/storage/azure-storage-file-share/CHANGELOG.md @@ -1,5 +1,9 @@ # Release History +## 12.19.0b1 (Unreleased) + +### Features Added + ## 12.18.0 (2024-09-17) ### Features Added diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_version.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_version.py index dc5b854a0ee..867c3e64caf 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_version.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_version.py @@ -4,4 +4,4 @@ # license information. 
# -------------------------------------------------------------------------- -VERSION = "12.18.0" +VERSION = "12.19.0b1" diff --git a/sdk/storage/azure-storage-file-share/setup.py b/sdk/storage/azure-storage-file-share/setup.py index 832fc9c7d6a..c5b2b5bd752 100644 --- a/sdk/storage/azure-storage-file-share/setup.py +++ b/sdk/storage/azure-storage-file-share/setup.py @@ -45,7 +45,7 @@ setup( url='https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-file-share', keywords="azure, azure sdk", classifiers=[ - 'Development Status :: 5 - Production/Stable', + 'Development Status :: 4 - Beta', 'Programming Language :: Python', 'Programming Language :: Python :: 3 :: Only', 'Programming Language :: Python :: 3', diff --git a/sdk/storage/azure-storage-queue/CHANGELOG.md b/sdk/storage/azure-storage-queue/CHANGELOG.md index 76ba40926a0..138c589cb22 100644 --- a/sdk/storage/azure-storage-queue/CHANGELOG.md +++ b/sdk/storage/azure-storage-queue/CHANGELOG.md @@ -1,5 +1,9 @@ # Release History +## 12.13.0b1 (Unreleased) + +### Features Added + ## 12.12.0 (2024-09-17) ### Features Added diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_version.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_version.py index 2962b1a7ddd..e49d53fd26e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_version.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_version.py @@ -4,4 +4,4 @@ # license information. 
# -------------------------------------------------------------------------- -VERSION = "12.12.0" +VERSION = "12.13.0b1" diff --git a/sdk/storage/azure-storage-queue/setup.py b/sdk/storage/azure-storage-queue/setup.py index ee2dda40b48..1b14673eb86 100644 --- a/sdk/storage/azure-storage-queue/setup.py +++ b/sdk/storage/azure-storage-queue/setup.py @@ -46,7 +46,7 @@ setup( url='https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-queue', keywords="azure, azure sdk", classifiers=[ - 'Development Status :: 5 - Production/Stable', + 'Development Status :: 4 - Beta', 'Programming Language :: Python', "Programming Language :: Python :: 3 :: Only", 'Programming Language :: Python :: 3', From ddd5c27661795add88f465399257f730ed2bf0db Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Tue, 17 Sep 2024 17:24:15 -0700 Subject: [PATCH 08/17] [Identity] Implement new protocol for all credentials (#36882) All credentials now implement the `SupportsTokenInfo/AsyncSupportsTokenInfo` protocol, by each having a `get_token_info` method implementation. This allows for more extensible authentication constructs. 
Signed-off-by: Paul Van Eck --- sdk/identity/azure-identity/CHANGELOG.md | 4 + sdk/identity/azure-identity/assets.json | 2 +- .../azure/identity/_bearer_token_provider.py | 4 +- .../identity/_credentials/application.py | 33 ++- .../_credentials/authorization_code.py | 31 +- .../azure/identity/_credentials/azd_cli.py | 39 ++- .../azure/identity/_credentials/azure_cli.py | 43 ++- .../identity/_credentials/azure_pipelines.py | 21 +- .../identity/_credentials/azure_powershell.py | 42 ++- .../azure/identity/_credentials/chained.py | 114 ++++++- .../identity/_credentials/client_assertion.py | 6 +- .../azure/identity/_credentials/default.py | 44 ++- .../identity/_credentials/environment.py | 30 +- .../azure/identity/_credentials/imds.py | 8 +- .../identity/_credentials/managed_identity.py | 35 ++- .../identity/_credentials/on_behalf_of.py | 14 +- .../identity/_credentials/shared_cache.py | 75 ++++- .../azure/identity/_credentials/silent.py | 56 +++- .../azure/identity/_credentials/vscode.py | 37 ++- .../azure/identity/_internal/aad_client.py | 18 +- .../identity/_internal/aad_client_base.py | 21 +- .../_internal/client_credential_base.py | 20 +- .../identity/_internal/get_token_mixin.py | 66 ++++- .../azure/identity/_internal/interactive.py | 91 +++++- .../_internal/managed_identity_base.py | 11 +- .../_internal/managed_identity_client.py | 28 +- .../_internal/msal_managed_identity_client.py | 66 ++++- .../identity/_internal/shared_token_cache.py | 7 +- .../identity/aio/_credentials/application.py | 36 ++- .../aio/_credentials/authorization_code.py | 33 ++- .../identity/aio/_credentials/azd_cli.py | 39 ++- .../identity/aio/_credentials/azure_cli.py | 38 ++- .../aio/_credentials/azure_pipelines.py | 21 +- .../aio/_credentials/azure_powershell.py | 38 ++- .../identity/aio/_credentials/certificate.py | 6 +- .../identity/aio/_credentials/chained.py | 104 ++++++- .../aio/_credentials/client_assertion.py | 6 +- .../aio/_credentials/client_secret.py | 6 +- 
.../identity/aio/_credentials/default.py | 49 ++- .../identity/aio/_credentials/environment.py | 31 +- .../azure/identity/aio/_credentials/imds.py | 10 +- .../aio/_credentials/managed_identity.py | 38 ++- .../identity/aio/_credentials/on_behalf_of.py | 6 +- .../identity/aio/_credentials/shared_cache.py | 59 +++- .../azure/identity/aio/_credentials/vscode.py | 37 ++- .../identity/aio/_internal/aad_client.py | 20 +- .../identity/aio/_internal/get_token_mixin.py | 66 ++++- .../aio/_internal/managed_identity_base.py | 11 +- .../aio/_internal/managed_identity_client.py | 4 +- sdk/identity/azure-identity/setup.py | 4 +- sdk/identity/azure-identity/tests/helpers.py | 8 +- .../azure-identity/tests/helpers_async.py | 9 - .../azure-identity/tests/test_aad_client.py | 6 +- .../tests/test_aad_client_async.py | 2 +- .../tests/test_app_service_async.py | 9 +- .../tests/test_application_credential.py | 46 +-- .../test_application_credential_async.py | 37 ++- .../azure-identity/tests/test_auth_code.py | 81 +++-- .../tests/test_auth_code_async.py | 71 +++-- .../azure-identity/tests/test_authority.py | 5 +- .../tests/test_azd_cli_credential.py | 103 ++++--- .../tests/test_azd_cli_credential_async.py | 103 ++++--- .../tests/test_azure_application.py | 5 +- .../azure-identity/tests/test_azure_arc.py | 7 +- .../tests/test_azure_pipelines_credential.py | 19 +- .../test_azure_pipelines_credential_async.py | 19 +- .../tests/test_bearer_token_provider.py | 12 +- .../tests/test_bearer_token_provider_async.py | 14 +- .../tests/test_browser_credential.py | 57 ++-- .../tests/test_certificate_credential.py | 119 +++++--- .../test_certificate_credential_async.py | 100 ++++--- .../tests/test_chained_credential.py | 199 ++++++++++--- .../test_chained_token_credential_async.py | 205 ++++++++++--- .../tests/test_cli_credential.py | 119 +++++--- .../tests/test_cli_credential_async.py | 119 +++++--- .../tests/test_client_assertion_credential.py | 13 +- .../test_client_assertion_credential_async.py | 
12 +- .../tests/test_client_secret_credential.py | 129 +++++--- .../test_client_secret_credential_async.py | 124 +++++--- .../tests/test_context_manager.py | 5 +- .../azure-identity/tests/test_default.py | 52 ++-- .../tests/test_default_async.py | 50 ++-- .../tests/test_device_code_credential.py | 58 ++-- .../tests/test_environment_credential.py | 9 +- .../test_environment_credential_async.py | 14 +- .../tests/test_get_token_mixin.py | 61 ++-- .../tests/test_get_token_mixin_async.py | 61 ++-- .../tests/test_imds_credential.py | 75 ++--- .../tests/test_imds_credential_async.py | 145 +++++---- .../tests/test_initialization.py | 74 +++++ .../tests/test_initialization_async.py | 69 +++++ .../tests/test_interactive_credential.py | 111 +++++-- .../azure-identity/tests/test_live.py | 56 ++-- .../azure-identity/tests/test_live_async.py | 46 +-- .../tests/test_managed_identity.py | 187 +++++++----- .../tests/test_managed_identity_async.py | 231 +++++++++------ .../tests/test_multi_tenant_auth.py | 11 +- .../tests/test_multi_tenant_auth_async.py | 11 +- sdk/identity/azure-identity/tests/test_obo.py | 49 +-- .../azure-identity/tests/test_obo_async.py | 50 ++-- .../tests/test_powershell_credential.py | 144 +++++---- .../tests/test_powershell_credential_async.py | 138 +++++---- .../tests/test_shared_cache_credential.py | 278 +++++++++++------- .../test_shared_cache_credential_async.py | 203 ++++++++----- .../test_username_password_credential.py | 54 ++-- .../tests/test_vscode_credential.py | 7 +- .../test_workload_identity_credential.py | 8 +- ...test_workload_identity_credential_async.py | 7 +- 108 files changed, 3999 insertions(+), 1645 deletions(-) create mode 100644 sdk/identity/azure-identity/tests/test_initialization.py create mode 100644 sdk/identity/azure-identity/tests/test_initialization_async.py diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 2f86649a7fc..6fae7686359 100644 --- 
a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -4,6 +4,9 @@ ### Features Added +- All credentials now support the `SupportsTokenInfo` protocol. Each credential now has a `get_token_info` method which returns an `AccessTokenInfo` object. The `get_token_info` method is an alternative method to `get_token` that improves support support for more complex authentication scenarios. ([#36882](https://github.com/Azure/azure-sdk-for-python/pull/36882)) + - Information on when a token should be refreshed is now saved in `AccessTokenInfo` (if available). + ### Breaking Changes ### Bugs Fixed @@ -12,6 +15,7 @@ - Added identity config validation to `ManagedIdentityCredential` to avoid non-deterministic states (e.g. both `resource_id` and `object_id` are specified). ([#36950](https://github.com/Azure/azure-sdk-for-python/pull/36950)) - Additional validation was added for `ManagedIdentityCredential` in Azure Cloud Shell environments. ([#36438](https://github.com/Azure/azure-sdk-for-python/issues/36438)) +- Bumped minimum dependency on `azure-core` to `>=1.31.0`. 
## 1.18.0b2 (2024-08-09) diff --git a/sdk/identity/azure-identity/assets.json b/sdk/identity/azure-identity/assets.json index be62aaed10a..477825b9d64 100644 --- a/sdk/identity/azure-identity/assets.json +++ b/sdk/identity/azure-identity/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "python", "TagPrefix": "python/identity/azure-identity", - "Tag": "python/identity/azure-identity_cb8dd6f319" + "Tag": "python/identity/azure-identity_61e626a4a0" } diff --git a/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py b/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py index 209f46d46ef..3617f56eab3 100644 --- a/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py +++ b/sdk/identity/azure-identity/azure/identity/_bearer_token_provider.py @@ -4,7 +4,7 @@ # ------------------------------------ from typing import Callable -from azure.core.credentials import TokenCredential +from azure.core.credentials import TokenProvider from azure.core.pipeline.policies import BearerTokenCredentialPolicy from azure.core.pipeline import PipelineRequest, PipelineContext from azure.core.rest import HttpRequest @@ -14,7 +14,7 @@ def _make_request() -> PipelineRequest[HttpRequest]: return PipelineRequest(HttpRequest("CredentialWrapper", "https://fakeurl"), PipelineContext(None)) -def get_bearer_token_provider(credential: TokenCredential, *scopes: str) -> Callable[[], str]: +def get_bearer_token_provider(credential: TokenProvider, *scopes: str) -> Callable[[], str]: """Returns a callable that provides a bearer token. 
It can be used for instance to write code like: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/application.py b/sdk/identity/azure-identity/azure/identity/_credentials/application.py index 852b44e7f71..81a06fd989d 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/application.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/application.py @@ -4,9 +4,9 @@ # ------------------------------------ import logging import os -from typing import Any, Optional +from typing import Any, Optional, cast -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo, TokenCredential from .chained import ChainedTokenCredential from .environment import EnvironmentCredential from .managed_identity import ManagedIdentityCredential @@ -83,10 +83,37 @@ class AzureApplicationCredential(ChainedTokenCredential): `message` attribute listing each authentication attempt and its error message. """ if self._successful_credential: - token = self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + token = cast(TokenCredential, self._successful_credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, **kwargs + ) _LOGGER.info( "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ ) return token return super(AzureApplicationCredential, self).get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. 
+ For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a + `message` attribute listing each authentication attempt and its error message. + """ + if self._successful_credential: + token_info = cast(SupportsTokenInfo, self._successful_credential).get_token_info(*scopes, options=options) + _LOGGER.info( + "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ + ) + return token_info + + return cast(SupportsTokenInfo, super()).get_token_info(*scopes, options=options) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py index b5997e60e80..c98aa26c9a0 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py @@ -4,7 +4,7 @@ # ------------------------------------ from typing import Optional, Any -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from .._internal.aad_client import AadClient from .._internal.get_token_mixin import GetTokenMixin @@ -90,10 +90,35 @@ class AuthorizationCodeCredential(GetTokenMixin): *scopes, claims=claims, tenant_id=tenant_id, client_secret=self._client_secret, **kwargs ) - def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]: + def get_token_info(self, *scopes: str, 
options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + The first time this method is called, the credential will redeem its authorization code. On subsequent calls + the credential will return a cached access token or redeem a refresh token, if it acquired a refresh token upon + redeeming the authorization code. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. Any error response from Microsoft Entra ID is available as the error's + ``response`` attribute. 
+ """ + return super()._get_token_base( + *scopes, options=options, client_secret=self._client_secret, base_method_name="get_token_info" + ) + + def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessTokenInfo]: return self._client.get_cached_access_token(scopes, **kwargs) - def _request_token(self, *scopes: str, **kwargs) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo: if self._authorization_code: token = self._client.obtain_token_by_authorization_code( scopes=scopes, code=self._authorization_code, redirect_uri=self._redirect_uri, **kwargs diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azd_cli.py b/sdk/identity/azure-identity/azure/identity/_credentials/azd_cli.py index 2af899ae2e6..319569482ab 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azd_cli.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azd_cli.py @@ -12,7 +12,7 @@ import subprocess import sys from typing import Any, Dict, List, Optional -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from .. import CredentialUnavailableError @@ -118,10 +118,43 @@ class AzureDeveloperCliCredential: :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked the Azure Developer CLI but didn't receive an access token. """ + options: TokenRequestOptions = {} + if tenant_id: + options["tenant_id"] = tenant_id + token_info = self._get_token_base(*scopes, options=options, **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + @log_get_token + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. 
This method is called automatically by Azure SDK clients. Applications calling this method + directly must also handle token caching because this credential doesn't cache the tokens it acquires. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke + the Azure Developer CLI. + :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked + the Azure Developer CLI but didn't receive an access token. + """ + return self._get_token_base(*scopes, options=options) + + def _get_token_base( + self, *scopes: str, options: Optional[TokenRequestOptions] = None, **kwargs: Any + ) -> AccessTokenInfo: if not scopes: raise ValueError("Missing scope in request. \n") + tenant_id = options.get("tenant_id") if options else None if tenant_id: validate_tenant_id(tenant_id) for scope in scopes: @@ -154,7 +187,7 @@ class AzureDeveloperCliCredential: return token -def parse_token(output: str) -> Optional[AccessToken]: +def parse_token(output: str) -> Optional[AccessTokenInfo]: """Parse to an AccessToken. In particular, convert the "expiresOn" value to epoch seconds. 
This value is a naive local datetime as returned by @@ -169,7 +202,7 @@ def parse_token(output: str) -> Optional[AccessToken]: dt = datetime.strptime(token["expiresOn"], "%Y-%m-%dT%H:%M:%SZ") expires_on = dt.timestamp() - return AccessToken(token["token"], int(expires_on)) + return AccessTokenInfo(token["token"], int(expires_on)) except (KeyError, ValueError): return None diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py index cb48b98356d..a6feae9d2d9 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py @@ -11,7 +11,7 @@ import subprocess import sys from typing import List, Optional, Any, Dict -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from .. import CredentialUnavailableError @@ -94,6 +94,41 @@ class AzureCliCredential: :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked the Azure CLI but didn't receive an access token. """ + + options: TokenRequestOptions = {} + if tenant_id: + options["tenant_id"] = tenant_id + + token_info = self._get_token_base(*scopes, options=options, **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + @log_get_token + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. Applications calling this method + directly must also handle token caching because this credential doesn't cache the tokens it acquires. + + :param str scopes: desired scopes for the access token. 
This credential allows only one scope per request. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke the Azure CLI. + :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked the Azure CLI but didn't + receive an access token. + """ + return self._get_token_base(*scopes, options=options) + + def _get_token_base( + self, *scopes: str, options: Optional[TokenRequestOptions] = None, **kwargs: Any + ) -> AccessTokenInfo: + + tenant_id = options.get("tenant_id") if options else None if tenant_id: validate_tenant_id(tenant_id) for scope in scopes: @@ -126,7 +161,7 @@ class AzureCliCredential: return token -def parse_token(output) -> Optional[AccessToken]: +def parse_token(output) -> Optional[AccessTokenInfo]: """Parse output of 'az account get-access-token' to an AccessToken. In particular, convert the "expiresOn" value to epoch seconds. This value is a naive local datetime as returned by @@ -141,11 +176,11 @@ def parse_token(output) -> Optional[AccessToken]: # Use "expires_on" if it's present, otherwise use "expiresOn". 
if "expires_on" in token: - return AccessToken(token["accessToken"], int(token["expires_on"])) + return AccessTokenInfo(token["accessToken"], int(token["expires_on"])) dt = datetime.strptime(token["expiresOn"], "%Y-%m-%d %H:%M:%S.%f") expires_on = dt.timestamp() - return AccessToken(token["accessToken"], int(expires_on)) + return AccessTokenInfo(token["accessToken"], int(expires_on)) except (KeyError, ValueError): return None diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_pipelines.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_pipelines.py index ea074b406cc..a981ecff7a4 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_pipelines.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_pipelines.py @@ -7,7 +7,7 @@ import os from typing import Any, Optional from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.rest import HttpRequest, HttpResponse from .client_assertion import ClientAssertionCredential @@ -125,6 +125,25 @@ class AzurePipelinesCredential: *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs ) + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. 
+ :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + validate_env_vars() + return self._client_assertion_credential.get_token_info(*scopes, options=options) + def _get_oidc_token(self) -> str: request = build_oidc_request(self._service_connection_id, self._system_access_token) response = self._pipeline.run(request, retry_on_methods=[request.method]) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py index da3dd2c45ab..92dd0432bce 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py @@ -8,7 +8,7 @@ import subprocess import sys from typing import Any, List, Tuple, Optional -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from .azure_cli import get_safe_working_dir @@ -125,6 +125,42 @@ class AzurePowerShellCredential: :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked Azure PowerShell but didn't receive an access token """ + + options: TokenRequestOptions = {} + if tenant_id: + options["tenant_id"] = tenant_id + + token_info = self._get_token_base(*scopes, options=options, **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + @log_get_token + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. 
+ + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. Applications calling this method + directly must also handle token caching because this credential doesn't cache the tokens it acquires. + + :param str scopes: desired scopes for the access token. This credential allows only one scope per request. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke Azure PowerShell, or + no account is authenticated + :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked Azure PowerShell but didn't + receive an access token + """ + return self._get_token_base(*scopes, options=options) + + def _get_token_base( + self, *scopes: str, options: Optional[TokenRequestOptions] = None, **kwargs: Any + ) -> AccessTokenInfo: + + tenant_id = options.get("tenant_id") if options else None if tenant_id: validate_tenant_id(tenant_id) for scope in scopes: @@ -185,11 +221,11 @@ def start_process(args: List[str]) -> "subprocess.Popen": return proc -def parse_token(output: str) -> AccessToken: +def parse_token(output: str) -> AccessTokenInfo: for line in output.split(): if line.startswith("azsdk%"): _, token, expires_on = line.split("%") - return AccessToken(token, int(expires_on)) + return AccessTokenInfo(token, int(expires_on)) if within_dac.get(): raise CredentialUnavailableError(message='Unexpected output from Get-AzAccessToken: "{}"'.format(output)) diff --git 
a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py index 10e03a6daa0..31f38937f1d 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py @@ -3,16 +3,20 @@ # Licensed under the MIT License. # ------------------------------------ import logging -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Optional, cast from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import ( + AccessToken, + AccessTokenInfo, + TokenRequestOptions, + SupportsTokenInfo, + TokenCredential, + TokenProvider, +) from .. import CredentialUnavailableError from .._internal import within_credential_chain -if TYPE_CHECKING: - from azure.core.credentials import TokenCredential - _LOGGER = logging.getLogger(__name__) @@ -48,12 +52,11 @@ class ChainedTokenCredential: :caption: Create a ChainedTokenCredential. """ - def __init__(self, *credentials): - # type: (*TokenCredential) -> None + def __init__(self, *credentials: TokenProvider) -> None: if not credentials: raise ValueError("at least one credential is required") - self._successful_credential = None # type: Optional[TokenCredential] + self._successful_credential: Optional[TokenProvider] = None self.credentials = credentials def __enter__(self) -> "ChainedTokenCredential": @@ -70,10 +73,18 @@ class ChainedTokenCredential: self.__exit__() def get_token( - self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any + self, + *scopes: str, + claims: Optional[str] = None, + tenant_id: Optional[str] = None, + enable_cae: bool = False, + **kwargs: Any, ) -> AccessToken: """Request a token from each chained credential, in order, returning the first token received. 
+ If no credential provides a token, raises :class:`azure.core.exceptions.ClientAuthenticationError` + with an error message from each credential. + This method is called automatically by Azure SDK clients. :param str scopes: desired scopes for the access token. This method requires at least one scope. @@ -82,20 +93,38 @@ class ChainedTokenCredential: :keyword str claims: additional claims required in the token, such as those returned in a resource provider's claims challenge following an authorization failure. :keyword str tenant_id: optional tenant to include in the token request. + :keyword bool enable_cae: indicates whether to enable Continuous Access Evaluation (CAE) for the requested + token. Defaults to False. :return: An access token with the desired scopes. :rtype: ~azure.core.credentials.AccessToken :raises ~azure.core.exceptions.ClientAuthenticationError: no credential in the chain provided a token """ + within_credential_chain.set(True) history = [] for credential in self.credentials: try: - token = credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + # Prioritize "get_token". Fall back to "get_token_info" if not available. 
+ if hasattr(credential, "get_token"): + token = cast(TokenCredential, credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs + ) + else: + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + token_info = cast(SupportsTokenInfo, credential).get_token_info(*scopes, options=options) + token = AccessToken(token_info.token, token_info.expires_on) + _LOGGER.info("%s acquired a token from %s", self.__class__.__name__, credential.__class__.__name__) self._successful_credential = credential within_credential_chain.set(False) return token + except CredentialUnavailableError as ex: # credential didn't attempt authentication because it lacks required data or state -> continue history.append((credential, ex.message)) @@ -110,6 +139,71 @@ class ChainedTokenCredential: exc_info=True, ) break + within_credential_chain.set(False) + attempts = _get_error_message(history) + message = ( + self.__class__.__name__ + + " failed to retrieve a token from the included credentials." + + attempts + + "\nTo mitigate this issue, please refer to the troubleshooting guidelines here at " + "https://aka.ms/azsdk/python/identity/defaultazurecredential/troubleshoot." + ) + _LOGGER.warning(message) + raise ClientAuthenticationError(message=message) + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request a token from each chained credential, in order, returning the first token received. + + If no credential provides a token, raises :class:`azure.core.exceptions.ClientAuthenticationError` + with an error message from each credential. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. 
This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.core.exceptions.ClientAuthenticationError: no credential in the chain provided a token. + """ + within_credential_chain.set(True) + history = [] + options = options or {} + for credential in self.credentials: + try: + # Prioritize "get_token_info". Fall back to "get_token" if not available. + if hasattr(credential, "get_token_info"): + token_info = cast(SupportsTokenInfo, credential).get_token_info(*scopes, options=options) + else: + if options.get("pop"): + raise CredentialUnavailableError( + "Proof of possession arguments are not supported for this credential." 
+ ) + token = cast(TokenCredential, credential).get_token(*scopes, **options) + token_info = AccessTokenInfo(token=token.token, expires_on=token.expires_on) + + _LOGGER.info("%s acquired a token from %s", self.__class__.__name__, credential.__class__.__name__) + self._successful_credential = credential + within_credential_chain.set(False) + return token_info + except CredentialUnavailableError as ex: + # credential didn't attempt authentication because it lacks required data or state -> continue + history.append((credential, ex.message)) + except Exception as ex: # pylint: disable=broad-except + # credential failed to authenticate, or something unexpectedly raised -> break + history.append((credential, str(ex))) + _LOGGER.debug( + '%s.get_token_info failed: %s raised unexpected error "%s"', + self.__class__.__name__, + credential.__class__.__name__, + ex, + exc_info=True, + ) + break within_credential_chain.set(False) attempts = _get_error_message(history) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py b/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py index 9970a2fb80e..bb371381c6b 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py @@ -4,7 +4,7 @@ # ------------------------------------ from typing import Callable, Optional, Any -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from .._internal import AadClient from .._internal.get_token_mixin import GetTokenMixin @@ -68,10 +68,10 @@ class ClientAssertionCredential(GetTokenMixin): def close(self) -> None: self.__exit__() - def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: return self._client.get_cached_access_token(scopes, **kwargs) - def 
_request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: assertion = self._func() token = self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs) return token diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/_credentials/default.py index 57446c062ca..035efb52bd3 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/default.py @@ -4,9 +4,9 @@ # ------------------------------------ import logging import os -from typing import List, TYPE_CHECKING, Any, Optional, cast +from typing import List, Any, Optional, cast -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo, TokenCredential from .._constants import EnvironmentVariables from .._internal import get_default_authority, normalize_authority, within_dac from .azure_powershell import AzurePowerShellCredential @@ -20,9 +20,6 @@ from .azd_cli import AzureDeveloperCliCredential from .vscode import VisualStudioCodeCredential from .workload_identity import WorkloadIdentityCredential -if TYPE_CHECKING: - from azure.core.credentials import TokenCredential - _LOGGER = logging.getLogger(__name__) @@ -144,7 +141,7 @@ class DefaultAzureCredential(ChainedTokenCredential): exclude_interactive_browser_credential = kwargs.pop("exclude_interactive_browser_credential", True) exclude_powershell_credential = kwargs.pop("exclude_powershell_credential", False) - credentials: List["TokenCredential"] = [] + credentials: List[SupportsTokenInfo] = [] within_dac.set(True) if not exclude_environment_credential: credentials.append(EnvironmentCredential(authority=authority, _within_dac=True, **kwargs)) @@ -214,10 +211,12 @@ class DefaultAzureCredential(ChainedTokenCredential): :rtype: 
~azure.core.credentials.AccessToken :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a - `message` attribute listing each authentication attempt and its error message. + `message` attribute listing each authentication attempt and its error message. """ if self._successful_credential: - token = self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + token = cast(TokenCredential, self._successful_credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, **kwargs + ) _LOGGER.info( "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ ) @@ -226,3 +225,32 @@ class DefaultAzureCredential(ChainedTokenCredential): token = super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) within_dac.set(False) return token + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a + `message` attribute listing each authentication attempt and its error message. 
+ """ + if self._successful_credential: + token_info = cast(SupportsTokenInfo, self._successful_credential).get_token_info(*scopes, options=options) + _LOGGER.info( + "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ + ) + return token_info + + within_dac.set(True) + token_info = cast(SupportsTokenInfo, super()).get_token_info(*scopes, options=options) + within_dac.set(False) + return token_info diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/environment.py b/sdk/identity/azure-identity/azure/identity/_credentials/environment.py index 146d9be5c9e..8f87a1d9ff9 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/environment.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/environment.py @@ -4,8 +4,8 @@ # ------------------------------------ import logging import os -from typing import Optional, Union, Any -from azure.core.credentials import AccessToken +from typing import Optional, Union, Any, cast +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, SupportsTokenInfo from .. import CredentialUnavailableError from .._constants import EnvironmentVariables @@ -155,3 +155,29 @@ class EnvironmentCredential: ) raise CredentialUnavailableError(message=message) return self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + + @log_get_token + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. 
+ :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: environment variable configuration is incomplete. + """ + if not self._credential: + message = ( + "EnvironmentCredential authentication unavailable. Environment variables are not fully configured.\n" + "Visit https://aka.ms/azsdk/python/identity/environmentcredential/troubleshoot to troubleshoot " + "this issue." + ) + raise CredentialUnavailableError(message=message) + return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py index cd9a0149afc..6528f23b83e 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py @@ -8,7 +8,7 @@ from typing import Any, Optional, Dict from azure.core.exceptions import ClientAuthenticationError, HttpResponseError from azure.core.pipeline.transport import HttpRequest -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from .. import CredentialUnavailableError from .._constants import EnvironmentVariables @@ -76,7 +76,7 @@ class ImdsCredential(MsalManagedIdentityClient): def close(self) -> None: self.__exit__() - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: if within_credential_chain.get() and not self._endpoint_available: # If within a chain (e.g. 
DefaultAzureCredential), we do a quick check to see if the IMDS endpoint @@ -96,7 +96,7 @@ class ImdsCredential(MsalManagedIdentityClient): raise CredentialUnavailableError(error_message) from ex try: - token = super()._request_token(*scopes) + token_info = super()._request_token(*scopes) except CredentialUnavailableError: # Response is not json, skip the IMDS credential raise @@ -123,7 +123,7 @@ class ImdsCredential(MsalManagedIdentityClient): # if anything else was raised, assume the endpoint is unavailable error_message = "ManagedIdentityCredential authentication unavailable, no response from the IMDS endpoint." raise CredentialUnavailableError(error_message) from ex - return token + return token_info def get_unavailable_message(self, desc: str = "") -> str: return f"ManagedIdentityCredential authentication unavailable, no response from the IMDS endpoint. {desc}" diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py index 8c8fc9012cf..db8667b1d51 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py @@ -4,15 +4,13 @@ # ------------------------------------ import logging import os -from typing import Optional, TYPE_CHECKING, Any, Mapping +from typing import Optional, Any, Mapping, cast -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions, TokenCredential, SupportsTokenInfo from .. 
import CredentialUnavailableError from .._constants import EnvironmentVariables from .._internal.decorators import log_get_token -if TYPE_CHECKING: - from azure.core.credentials import TokenCredential _LOGGER = logging.getLogger(__name__) @@ -62,7 +60,7 @@ class ManagedIdentityCredential: self, *, client_id: Optional[str] = None, identity_config: Optional[Mapping[str, str]] = None, **kwargs: Any ) -> None: validate_identity_config(client_id, identity_config) - self._credential: Optional[TokenCredential] = None + self._credential: Optional[SupportsTokenInfo] = None exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False) if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT): if os.environ.get(EnvironmentVariables.IDENTITY_HEADER): @@ -159,4 +157,29 @@ class ManagedIdentityCredential: "Visit https://aka.ms/azsdk/python/identity/managedidentitycredential/troubleshoot to " "troubleshoot this issue." ) - return self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + return cast(TokenCredential, self._credential).get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + + @log_get_token + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This credential allows only one scope per request. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. 
+ :raises ~azure.identity.CredentialUnavailableError: managed identity isn't available in the hosting environment. + """ + if not self._credential: + raise CredentialUnavailableError( + message="No managed identity endpoint found. \n" + "The Target Azure platform could not be determined from environment variables. \n" + "Visit https://aka.ms/azsdk/python/identity/managedidentitycredential/troubleshoot to " + "troubleshoot this issue." + ) + return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py b/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py index 9f8889bd44f..93675adb84e 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py @@ -7,7 +7,7 @@ from typing import Any, Optional, Callable, Union, Dict import msal -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from azure.core.exceptions import ClientAuthenticationError from .certificate import get_client_credential @@ -123,7 +123,7 @@ class OnBehalfOfCredential(MsalCredential, GetTokenMixin): self._auth_record: Optional[AuthenticationRecord] = None @wrap_exceptions - def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: if self._auth_record: claims = kwargs.get("claims") app = self._get_app(**kwargs) @@ -134,12 +134,15 @@ class OnBehalfOfCredential(MsalCredential, GetTokenMixin): now = int(time.time()) result = app.acquire_token_silent_with_error(list(scopes), account=account, claims_challenge=claims) if result and "access_token" in result and "expires_in" in result: - return AccessToken(result["access_token"], now + int(result["expires_in"])) + refresh_on = int(result["refresh_on"]) if 
"refresh_on" in result else None + return AccessTokenInfo( + result["access_token"], now + int(result["expires_in"]), refresh_on=refresh_on + ) return None @wrap_exceptions - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: app: msal.ConfidentialClientApplication = self._get_app(**kwargs) request_time = int(time.time()) result = app.acquire_token_on_behalf_of(self._assertion, list(scopes), claims_challenge=kwargs.get("claims")) @@ -153,4 +156,5 @@ class OnBehalfOfCredential(MsalCredential, GetTokenMixin): except ClientAuthenticationError: pass # non-fatal; we'll use the assertion again next time instead of a refresh token - return AccessToken(result["access_token"], request_time + int(result["expires_in"])) + refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None + return AccessTokenInfo(result["access_token"], request_time + int(result["expires_in"]), refresh_on=refresh_on) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index 0aaaf11cab5..39e895c6997 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -2,8 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import TYPE_CHECKING, Any, Optional, TypeVar -from azure.core.credentials import AccessToken +from typing import Any, Optional, TypeVar, cast +from azure.core.credentials import AccessToken, TokenRequestOptions, AccessTokenInfo, SupportsTokenInfo, TokenCredential from .silent import SilentAuthenticationCredential from .. 
import CredentialUnavailableError @@ -12,9 +12,6 @@ from .._internal import AadClient, AadClientBase from .._internal.decorators import log_get_token from .._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase -if TYPE_CHECKING: - from azure.core.credentials import TokenCredential - T = TypeVar("T", bound="_SharedTokenCacheCredential") @@ -39,7 +36,7 @@ class SharedTokenCacheCredential: def __init__(self, username: Optional[str] = None, **kwargs: Any) -> None: if "authentication_record" in kwargs: - self._credential = SilentAuthenticationCredential(**kwargs) # type: TokenCredential + self._credential: SupportsTokenInfo = SilentAuthenticationCredential(**kwargs) else: self._credential = _SharedTokenCacheCredential(username=username, **kwargs) @@ -61,7 +58,7 @@ class SharedTokenCacheCredential: claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: """Get an access token for `scopes` from the shared cache. @@ -85,7 +82,32 @@ class SharedTokenCacheCredential: :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` attribute gives a reason. """ - return self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs) + return cast(TokenCredential, self._credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs + ) + + @log_get_token + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + If no access token is cached, attempt to acquire one using a cached refresh token. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This method requires at least one scope. 
+ For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user + information. + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) @staticmethod def supported() -> bool: @@ -109,21 +131,48 @@ class _SharedTokenCacheCredential(SharedTokenCacheBase): if self._client: self._client.__exit__(*args) + def close(self) -> None: + self.__exit__() + def get_token( self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + + token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + return self._get_token_base(*scopes, options=options, base_method_name="get_token_info") + + def _get_token_base( + self, + *scopes: str, + options: Optional[TokenRequestOptions] = None, + base_method_name: str = "get_token_info", + **kwargs: Any, + ) -> AccessTokenInfo: if not scopes: - raise ValueError("'get_token' requires at least one scope") + raise ValueError(f"'{base_method_name}' requires at least one scope") if 
not self._client_initialized: self._initialize_client() - is_cae = enable_cae + options = options or {} + claims = options.get("claims") + tenant_id = options.get("tenant_id") + is_cae = options.get("enable_cae", False) + token_cache = self._cae_cache if is_cae else self._cache # Try to load the cache if it is None. @@ -142,8 +191,8 @@ class _SharedTokenCacheCredential(SharedTokenCacheBase): # try each refresh token, returning the first access token acquired for refresh_token in self._get_refresh_tokens(account, is_cae=is_cae): - token = self._client.obtain_token_by_refresh_token( - scopes, refresh_token, claims=claims, tenant_id=tenant_id, **kwargs + token = cast(AadClient, self._client).obtain_token_by_refresh_token( + scopes, refresh_token, claims=claims, tenant_id=tenant_id, enable_cae=is_cae, **kwargs ) return token diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/silent.py b/sdk/identity/azure-identity/azure/identity/_credentials/silent.py index 7a86656c23d..80d170a629d 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/silent.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/silent.py @@ -8,7 +8,7 @@ from typing import Dict, Optional, Any from msal import PublicClientApplication, TokenCache -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from .. 
import CredentialUnavailableError @@ -58,17 +58,51 @@ class SilentAuthenticationCredential: def __exit__(self, *args): self._client.__exit__(*args) - def get_token( - self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any - ) -> AccessToken: - if not scopes: - raise ValueError('"get_token" requires at least one scope') + def close(self) -> None: + self.__exit__() - token_cache = self._cae_cache if kwargs.get("enable_cae") else self._cache + def get_token( + self, + *scopes: str, + claims: Optional[str] = None, + tenant_id: Optional[str] = None, + enable_cae: bool = False, + **kwargs: Any, + ) -> AccessToken: + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + + token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + return self._get_token_base(*scopes, options=options, base_method_name="get_token_info") + + def _get_token_base( + self, + *scopes: str, + options: Optional[TokenRequestOptions] = None, + base_method_name: str = "get_token_info", + **kwargs: Any, + ) -> AccessTokenInfo: + + if not scopes: + raise ValueError(f"'{base_method_name}' requires at least one scope") + + options = options or {} + claims = options.get("claims") + tenant_id = options.get("tenant_id") + enable_cae = options.get("enable_cae", False) + + token_cache = self._cae_cache if enable_cae else self._cache # Try to load the cache if it is None. if not token_cache: - token_cache = self._initialize_cache(is_cae=bool(kwargs.get("enable_cae"))) + token_cache = self._initialize_cache(is_cae=enable_cae) # If the cache is still None, raise an error. 
if not token_cache: @@ -76,7 +110,7 @@ class SilentAuthenticationCredential: raise CredentialUnavailableError(message="Shared token cache unavailable") raise ClientAuthenticationError(message="Shared token cache unavailable") - return self._acquire_token_silent(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + return self._acquire_token_silent(*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs) def _initialize_cache(self, is_cae: bool = False) -> Optional[TokenCache]: @@ -129,7 +163,7 @@ class SilentAuthenticationCredential: return client_applications_map[tenant_id] @wrap_exceptions - def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: """Silently acquire a token from MSAL. :param str scopes: desired scopes for the access token @@ -153,7 +187,7 @@ class SilentAuthenticationCredential: list(scopes), account=account, claims_challenge=kwargs.get("claims") ) if result and "access_token" in result and "expires_in" in result: - return AccessToken(result["access_token"], now + int(result["expires_in"])) + return AccessTokenInfo(result["access_token"], now + int(result["expires_in"])) # if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently if result: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py index 4a1ca8a26c2..4990d707629 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py @@ -7,7 +7,7 @@ import os import sys from typing import cast, Any, Dict, Optional -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, TokenRequestOptions, AccessTokenInfo from azure.core.exceptions import ClientAuthenticationError from .._exceptions import 
CredentialUnavailableError from .._constants import AzureAuthorityHosts, AZURE_VSCODE_CLIENT_ID, EnvironmentVariables @@ -174,11 +174,42 @@ class VisualStudioCodeCredential(_VSCodeCredentialBase, GetTokenMixin): raise CredentialUnavailableError(message=ex.message) from ex return super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) - def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes` as the user currently signed in to Visual Studio Code. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.identity.CredentialUnavailableError: the credential cannot retrieve user details from Visual + Studio Code. + """ + if self._unavailable_reason: + error_message = ( + self._unavailable_reason + "\n" + "Visit https://aka.ms/azsdk/python/identity/vscodecredential/troubleshoot" + " to troubleshoot this issue." 
+ ) + raise CredentialUnavailableError(message=error_message) + if within_dac.get(): + try: + token = super().get_token_info(*scopes, options=options) + return token + except ClientAuthenticationError as ex: + raise CredentialUnavailableError(message=ex.message) from ex + return super().get_token_info(*scopes, options=options) + + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: self._client = cast(AadClient, self._client) return self._client.get_cached_access_token(scopes, **kwargs) - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: refresh_token = self._get_refresh_token() self._client = cast(AadClient, self._client) return self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py index f12ff5618aa..02fb0c922f5 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -5,7 +5,7 @@ import time from typing import Iterable, Union, Optional, Any -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from azure.core.pipeline import Pipeline from azure.core.pipeline.transport import HttpRequest from .aad_client_base import AadClientBase @@ -26,7 +26,7 @@ class AadClient(AadClientBase): def obtain_token_by_authorization_code( self, scopes: Iterable[str], code: str, redirect_uri: str, client_secret: Optional[str] = None, **kwargs: Any - ) -> AccessToken: + ) -> AccessTokenInfo: request = self._get_auth_code_request( scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs ) @@ -34,19 +34,21 @@ class AadClient(AadClientBase): def obtain_token_by_client_certificate( self, scopes: Iterable[str], 
certificate: AadClientCertificate, **kwargs: Any - ) -> AccessToken: + ) -> AccessTokenInfo: request = self._get_client_certificate_request(scopes, certificate, **kwargs) return self._run_pipeline(request, **kwargs) - def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs: Any) -> AccessToken: + def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs: Any) -> AccessTokenInfo: request = self._get_client_secret_request(scopes, secret, **kwargs) return self._run_pipeline(request, **kwargs) - def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs: Any) -> AccessToken: + def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs: Any) -> AccessTokenInfo: request = self._get_jwt_assertion_request(scopes, assertion, **kwargs) return self._run_pipeline(request, **kwargs) - def obtain_token_by_refresh_token(self, scopes: Iterable[str], refresh_token: str, **kwargs: Any) -> AccessToken: + def obtain_token_by_refresh_token( + self, scopes: Iterable[str], refresh_token: str, **kwargs: Any + ) -> AccessTokenInfo: request = self._get_refresh_token_request(scopes, refresh_token, **kwargs) return self._run_pipeline(request, **kwargs) @@ -56,14 +58,14 @@ class AadClient(AadClientBase): client_credential: Union[str, AadClientCertificate], user_assertion: str, **kwargs: Any - ) -> AccessToken: + ) -> AccessTokenInfo: # no need for an implementation, non-async OnBehalfOfCredential acquires tokens through MSAL raise NotImplementedError() def _build_pipeline(self, **kwargs: Any) -> Pipeline: return build_pipeline(**kwargs) - def _run_pipeline(self, request: HttpRequest, **kwargs: Any) -> AccessToken: + def _run_pipeline(self, request: HttpRequest, **kwargs: Any) -> AccessTokenInfo: # remove tenant_id and claims kwarg that could have been passed from credential's get_token method # tenant_id is already part of `request` at this point kwargs.pop("tenant_id", None) diff 
--git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index c41f87d2361..209a4bf8a79 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -7,14 +7,14 @@ import base64 import json import time from uuid import uuid4 -from typing import TYPE_CHECKING, List, Any, Iterable, Optional, Union, Dict +from typing import TYPE_CHECKING, List, Any, Iterable, Optional, Union, Dict, cast from msal import TokenCache from azure.core.pipeline import PipelineResponse from azure.core.pipeline.policies import ContentDecodePolicy from azure.core.pipeline.transport import HttpRequest -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from azure.core.exceptions import ClientAuthenticationError from .utils import get_default_authority, normalize_authority, resolve_tenant from .aadclient_certificate import AadClientCertificate @@ -79,9 +79,9 @@ class AadClientBase(abc.ABC): self._cae_cache = TokenCache() else: self._cache = TokenCache() - return self._cae_cache if is_cae else self._cache + return cast(TokenCache, self._cae_cache if is_cae else self._cache) - def get_cached_access_token(self, scopes: Iterable[str], **kwargs: Any) -> Optional[AccessToken]: + def get_cached_access_token(self, scopes: Iterable[str], **kwargs: Any) -> Optional[AccessTokenInfo]: tenant = resolve_tenant( self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs ) @@ -94,7 +94,8 @@ class AadClientBase(abc.ABC): ): expires_on = int(token["expires_on"]) if expires_on > int(time.time()): - return AccessToken(token["secret"], expires_on) + refresh_on = int(token["refresh_on"]) if "refresh_on" in token else None + return AccessTokenInfo(token["secret"], expires_on, refresh_on=refresh_on) return None def 
get_cached_refresh_tokens(self, scopes: Iterable[str], **kwargs) -> List[Dict]: @@ -130,7 +131,7 @@ class AadClientBase(abc.ABC): def _build_pipeline(self, **kwargs): pass - def _process_response(self, response: PipelineResponse, request_time: int, **kwargs) -> AccessToken: + def _process_response(self, response: PipelineResponse, request_time: int, **kwargs) -> AccessTokenInfo: content = response.context.get( ContentDecodePolicy.CONTEXT_NAME ) or ContentDecodePolicy.deserialize_from_http_generics(response.http_response) @@ -171,7 +172,13 @@ class AadClientBase(abc.ABC): _scrub_secrets(content) raise ClientAuthenticationError(message="Unexpected response from Microsoft Entra ID: {}".format(content)) - token = AccessToken(content["access_token"], expires_on) + expires_in = int(content.get("expires_in") or expires_on - request_time) + if "refresh_in" not in content and expires_in >= 7200: + # MSAL TokenCache expects "refresh_in" + content["refresh_in"] = expires_in // 2 + + refresh_on = request_time + int(content["refresh_in"]) if "refresh_in" in content else None + token = AccessTokenInfo(content["access_token"], expires_on, refresh_on=refresh_on) # caching is the final step because 'add' mutates 'content' cache.add( diff --git a/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py b/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py index 81579e3bdec..7685e8522c3 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py @@ -5,7 +5,7 @@ import time from typing import Any, Optional, Dict -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from azure.core.exceptions import ClientAuthenticationError from .get_token_mixin import GetTokenMixin @@ -23,18 +23,23 @@ class ClientCredentialBase(MsalCredential, GetTokenMixin): """Base class for credentials 
authenticating a service principal with a certificate or secret""" @wrap_exceptions - def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: app = self._get_app(**kwargs) request_time = int(time.time()) result = app.acquire_token_silent_with_error( list(scopes), account=None, claims_challenge=kwargs.pop("claims", None), **_get_known_kwargs(kwargs) ) if result and "access_token" in result and "expires_in" in result: - return AccessToken(result["access_token"], request_time + int(result["expires_in"])) + refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None + return AccessTokenInfo( + result["access_token"], + request_time + int(result["expires_in"]), + refresh_on=refresh_on, + ) return None @wrap_exceptions - def _request_token(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: app = self._get_app(**kwargs) request_time = int(time.time()) result = app.acquire_token_for_client(list(scopes), claims_challenge=kwargs.pop("claims", None)) @@ -42,4 +47,9 @@ class ClientCredentialBase(MsalCredential, GetTokenMixin): message = "Authentication failed: {}".format(result.get("error_description") or result.get("error")) raise ClientAuthenticationError(message=message) - return AccessToken(result["access_token"], request_time + int(result["expires_in"])) + refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None + return AccessTokenInfo( + result["access_token"], + request_time + int(result["expires_in"]), + refresh_on=refresh_on, + ) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py b/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py index 022555b5998..16fd1ea9f9c 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py +++ 
b/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py @@ -7,7 +7,7 @@ import logging import time from typing import Any, Optional -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from .utils import within_credential_chain from .._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY @@ -22,7 +22,7 @@ class GetTokenMixin(abc.ABC): super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore @abc.abstractmethod - def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: """Attempt to acquire an access token from a cache or by redeeming a refresh token. :param str scopes: desired scopes for the access token. This method requires at least one scope. @@ -30,11 +30,11 @@ class GetTokenMixin(abc.ABC): https://learn.microsoft.com/entra/identity-platform/scopes-oidc. :return: An access token with the desired scopes if successful; otherwise, None. - :rtype: ~azure.core.credentials.AccessToken or None + :rtype: ~azure.core.credentials.AccessTokenInfo or None """ @abc.abstractmethod - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: """Request an access token from the STS. :param str scopes: desired scopes for the access token. This method requires at least one scope. @@ -42,11 +42,13 @@ class GetTokenMixin(abc.ABC): https://learn.microsoft.com/entra/identity-platform/scopes-oidc. :return: An access token with the desired scopes. 
- :rtype: ~azure.core.credentials.AccessToken + :rtype: ~azure.core.credentials.AccessTokenInfo """ - def _should_refresh(self, token: AccessToken) -> bool: + def _should_refresh(self, token: AccessTokenInfo) -> bool: now = int(time.time()) + if token.refresh_on is not None and now >= token.refresh_on: + return True if token.expires_on - now > DEFAULT_REFRESH_OFFSET: return False if now - self._last_request_time < DEFAULT_TOKEN_REFRESH_RETRY_DELAY: @@ -59,7 +61,7 @@ class GetTokenMixin(abc.ABC): claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: """Request an access token for `scopes`. @@ -81,8 +83,50 @@ class GetTokenMixin(abc.ABC): :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` attribute gives a reason. """ + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + + token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. 
+ :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks + required data, state, or platform support + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + return self._get_token_base(*scopes, options=options, base_method_name="get_token_info") + + def _get_token_base( + self, + *scopes: str, + options: Optional[TokenRequestOptions] = None, + base_method_name: str = "get_token_info", + **kwargs: Any, + ) -> AccessTokenInfo: if not scopes: - raise ValueError('"get_token" requires at least one scope') + raise ValueError(f'"{base_method_name}" requires at least one scope') + + options = options or {} + claims = options.get("claims") + tenant_id = options.get("tenant_id") + enable_cae = options.get("enable_cae", False) try: token = self._acquire_token_silently( @@ -103,16 +147,18 @@ class GetTokenMixin(abc.ABC): pass _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.INFO, - "%s.get_token succeeded", + "%s.%s succeeded", self.__class__.__name__, + base_method_name, ) return token except Exception as ex: _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.WARNING, - "%s.get_token failed: %s", + "%s.%s failed: %s", self.__class__.__name__, + base_method_name, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py index d3e671e3069..c2665ee1593 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py @@ -9,10 +9,10 @@ import base64 import json import logging import time -from typing 
import Any, Optional, Iterable +from typing import Any, Optional, Iterable, Dict from urllib.parse import urlparse -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from .msal_credentials import MsalCredential @@ -95,7 +95,7 @@ class InteractiveCredential(MsalCredential, ABC): *, authentication_record: Optional[AuthenticationRecord] = None, disable_automatic_authentication: bool = False, - **kwargs: Any + **kwargs: Any, ) -> None: self._disable_automatic_authentication = disable_automatic_authentication self._auth_record = authentication_record @@ -106,7 +106,7 @@ class InteractiveCredential(MsalCredential, ABC): client_id=self._auth_record.client_id, authority=self._auth_record.authority, tenant_id=tenant_id, - **kwargs + **kwargs, ) else: super(InteractiveCredential, self).__init__(**kwargs) @@ -117,7 +117,7 @@ class InteractiveCredential(MsalCredential, ABC): claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: """Request an access token for `scopes`. @@ -140,23 +140,74 @@ class InteractiveCredential(MsalCredential, ABC): :raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential is configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication. """ + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + + token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. 
+ + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks + required data, state, or platform support + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + :raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential is + configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication. 
+ """ + return self._get_token_base(*scopes, options=options, base_method_name="get_token_info") + + def _get_token_base( + self, + *scopes: str, + options: Optional[TokenRequestOptions] = None, + base_method_name: str = "get_token_info", + **kwargs: Any, + ) -> AccessTokenInfo: if not scopes: - message = "'get_token' requires at least one scope" - _LOGGER.warning("%s.get_token failed: %s", self.__class__.__name__, message) + message = f"'{base_method_name}' requires at least one scope" + _LOGGER.warning("%s.%s failed: %s", self.__class__.__name__, base_method_name, message) raise ValueError(message) allow_prompt = kwargs.pop("_allow_prompt", not self._disable_automatic_authentication) + options = options or {} + claims = options.get("claims") + tenant_id = options.get("tenant_id") + enable_cae = options.get("enable_cae", False) + + # Check for arbitrary additional options to enable intermediary support for PoP tokens. + for key in options: + if key not in TokenRequestOptions.__annotations__: # pylint:disable=no-member + kwargs.setdefault(key, options[key]) # type: ignore + try: token = self._acquire_token_silent( *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs ) - _LOGGER.info("%s.get_token succeeded", self.__class__.__name__) + _LOGGER.info("%s.%s succeeded", self.__class__.__name__, base_method_name) return token except Exception as ex: # pylint:disable=broad-except if not (isinstance(ex, AuthenticationRequiredError) and allow_prompt): _LOGGER.warning( - "%s.get_token failed: %s", + "%s.%s failed: %s", self.__class__.__name__, + base_method_name, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) @@ -176,15 +227,22 @@ class InteractiveCredential(MsalCredential, ABC): self._auth_record = _build_auth_record(result) except Exception as ex: # pylint:disable=broad-except _LOGGER.warning( - "%s.get_token failed: %s", + "%s.%s failed: %s", self.__class__.__name__, + base_method_name, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) 
raise - _LOGGER.info("%s.get_token succeeded", self.__class__.__name__) - return AccessToken(result["access_token"], now + int(result["expires_in"])) + _LOGGER.info("%s.%s succeeded", self.__class__.__name__, base_method_name) + refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None + return AccessTokenInfo( + result["access_token"], + now + int(result["expires_in"]), + token_type=result.get("token_type", "Bearer"), + refresh_on=refresh_on, + ) def authenticate( self, *, scopes: Optional[Iterable[str]] = None, claims: Optional[str] = None, **kwargs: Any @@ -214,7 +272,7 @@ class InteractiveCredential(MsalCredential, ABC): return self._auth_record # type: ignore @wrap_exceptions - def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: result = None claims = kwargs.get("claims") if self._auth_record: @@ -226,7 +284,10 @@ class InteractiveCredential(MsalCredential, ABC): now = int(time.time()) result = app.acquire_token_silent_with_error(list(scopes), account=account, claims_challenge=claims) if result and "access_token" in result and "expires_in" in result: - return AccessToken(result["access_token"], now + int(result["expires_in"])) + refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None + return AccessTokenInfo( + result["access_token"], now + int(result["expires_in"]), refresh_on=refresh_on + ) # if we get this far, result is either None or the content of a Microsoft Entra ID error response if result: @@ -235,5 +296,5 @@ class InteractiveCredential(MsalCredential, ABC): raise AuthenticationRequiredError(scopes, claims=claims) @abc.abstractmethod - def _request_token(self, *scopes, **kwargs): + def _request_token(self, *scopes, **kwargs) -> Dict: pass diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py 
index 540e52b4f71..949ac14a844 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py @@ -5,7 +5,7 @@ import abc from typing import cast, Any, Optional, TypeVar -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from .. import CredentialUnavailableError from .._internal.managed_identity_client import ManagedIdentityClient from .._internal.get_token_mixin import GetTokenMixin @@ -47,10 +47,15 @@ class ManagedIdentityBase(GetTokenMixin): raise CredentialUnavailableError(message=self.get_unavailable_message()) return super(ManagedIdentityBase, self).get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) - def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + if not self._client: + raise CredentialUnavailableError(message=self.get_unavailable_message()) + return super().get_token_info(*scopes, options=options) + + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: # casting because mypy can't determine that these methods are called # only by get_token, which raises when self._client is None return cast(ManagedIdentityClient, self._client).get_cached_token(*scopes) - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: return cast(ManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py index 1b867cc8440..58b5e29b987 100644 --- 
a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py @@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, Optional from msal import TokenCache -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from azure.core.exceptions import ClientAuthenticationError, DecodeError from azure.core.pipeline.policies import ContentDecodePolicy from azure.core.pipeline import PipelineResponse @@ -40,7 +40,7 @@ class ManagedIdentityClientBase(abc.ABC): self._pipeline = self._build_pipeline(**kwargs) self._request_factory = request_factory - def _process_response(self, response: PipelineResponse, request_time: int) -> AccessToken: + def _process_response(self, response: PipelineResponse, request_time: int) -> AccessTokenInfo: content = response.context.get(ContentDecodePolicy.CONTEXT_NAME) if not content: try: @@ -70,7 +70,13 @@ class ManagedIdentityClientBase(abc.ABC): expires_on = int(content.get("expires_on") or int(content["expires_in"]) + request_time) content["expires_on"] = expires_on - token = AccessToken(content["access_token"], content["expires_on"]) + expires_in = int(content.get("expires_in") or expires_on - request_time) + if "refresh_in" not in content and expires_in >= 7200: + # MSAL TokenCache expects "refresh_in" + content["refresh_in"] = expires_in // 2 + + refresh_on = request_time + int(content["refresh_in"]) if "refresh_in" in content else None + token = AccessTokenInfo(content["access_token"], content["expires_on"], refresh_on=refresh_on) # caching is the final step because TokenCache.add mutates its "event" self._cache.add( @@ -80,15 +86,15 @@ class ManagedIdentityClientBase(abc.ABC): return token - def get_cached_token(self, *scopes: str) -> Optional[AccessToken]: + def get_cached_token(self, *scopes: str) -> Optional[AccessTokenInfo]: resource = _scopes_to_resource(*scopes) - for token in 
self._cache.search( - TokenCache.CredentialType.ACCESS_TOKEN, - target=[resource], - ): + now = time.time() + for token in self._cache.search(TokenCache.CredentialType.ACCESS_TOKEN, target=[resource]): expires_on = int(token["expires_on"]) - if expires_on > time.time(): - return AccessToken(token["secret"], expires_on) + refresh_on = int(token["refresh_on"]) if "refresh_on" in token else None + if expires_on > now and (not refresh_on or refresh_on > now): + return AccessTokenInfo(token["secret"], expires_on, refresh_on=refresh_on) + return None @abc.abstractmethod @@ -124,7 +130,7 @@ class ManagedIdentityClient(ManagedIdentityClientBase): def close(self) -> None: self.__exit__() - def request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + def request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: resource = _scopes_to_resource(*scopes) request = self._request_factory(resource, self._identity_config) kwargs.pop("tenant_id", None) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py index cee1acc0345..1c768ec0924 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py @@ -8,7 +8,7 @@ import time import logging import msal -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from .msal_client import MsalClient @@ -45,14 +45,15 @@ class MsalManagedIdentityClient(abc.ABC): # pylint:disable=client-accepts-api-v def close(self) -> None: self.__exit__() - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: # 
pylint:disable=unused-argument if not scopes: raise ValueError('"get_token" requires at least one scope') resource = _scopes_to_resource(*scopes) result = self._msal_client.acquire_token_for_client(resource=resource) now = int(time.time()) if result and "access_token" in result and "expires_in" in result: - return AccessToken(result["access_token"], now + int(result["expires_in"])) + refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None + return AccessTokenInfo(result["access_token"], now + int(result["expires_in"]), refresh_on=refresh_on) if result and "error" in result: error_desc = cast(str, result["error"]) error_message = self.get_unavailable_message(error_desc) @@ -83,7 +84,7 @@ class MsalManagedIdentityClient(abc.ABC): # pylint:disable=client-accepts-api-v claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: """Request an access token for `scopes`. @@ -105,31 +106,77 @@ class MsalManagedIdentityClient(abc.ABC): # pylint:disable=client-accepts-api-v :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` attribute gives a reason. """ + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + + token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. 
+ For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks + required data, state, or platform support + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + return self._get_token_base(*scopes, options=options, base_method_name="get_token_info") + + def _get_token_base( + self, + *scopes: str, + options: Optional[TokenRequestOptions] = None, + base_method_name: str = "get_token_info", + **kwargs: Any, + ) -> AccessTokenInfo: if not scopes: - raise ValueError('"get_token" requires at least one scope') + raise ValueError(f'"{base_method_name}" requires at least one scope') _scopes_to_resource(*scopes) token = None + + options = options or {} + claims = options.get("claims") + tenant_id = options.get("tenant_id") + enable_cae = options.get("enable_cae", False) + try: token = self._request_token(*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs) if token: _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.INFO, - "%s.get_token succeeded", + "%s.%s succeeded", self.__class__.__name__, + base_method_name, ) return token _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.WARNING, - "%s.get_token failed", + "%s.%s failed", self.__class__.__name__, + base_method_name, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) raise CredentialUnavailableError(self.get_unavailable_message()) except msal.ManagedIdentityError as ex: _LOGGER.log( logging.DEBUG if within_credential_chain.get() else 
logging.WARNING, - "%s.get_token failed: %s", + "%s.%s failed: %s", self.__class__.__name__, + base_method_name, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) @@ -137,8 +184,9 @@ class MsalManagedIdentityClient(abc.ABC): # pylint:disable=client-accepts-api-v except Exception as ex: # pylint:disable=broad-except _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.WARNING, - "%s.get_token failed: %s", + "%s.%s failed: %s", self.__class__.__name__, + base_method_name, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py index 867b255a0c7..4afb4bbdef8 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py @@ -9,7 +9,7 @@ from typing import Any, Iterable, List, Mapping, Optional, cast, Dict from urllib.parse import urlparse import msal -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from .. 
import CredentialUnavailableError from .._constants import KnownAuthorities from .._internal import get_default_authority, normalize_authority, wrap_exceptions @@ -228,7 +228,7 @@ class SharedTokenCacheBase(ABC): # pylint: disable=too-many-instance-attributes def _get_cached_access_token( self, scopes: Iterable[str], account: CacheItem, is_cae: bool = False - ) -> Optional[AccessToken]: + ) -> Optional[AccessTokenInfo]: if "home_account_id" not in account: return None @@ -241,8 +241,9 @@ class SharedTokenCacheBase(ABC): # pylint: disable=too-many-instance-attributes ) for token in cache_entries: expires_on = int(token["expires_on"]) + refresh_on = int(token["refresh_on"]) if "refresh_on" in token else None if expires_on - 300 > int(time.time()): - return AccessToken(token["secret"], expires_on) + return AccessTokenInfo(token["secret"], expires_on, refresh_on=refresh_on) except Exception as ex: # pylint:disable=broad-except message = "Error accessing cached data: {}".format(ex) raise CredentialUnavailableError(message=message) from ex diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py index 8a795b4afc2..980171ef1e6 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py @@ -4,9 +4,10 @@ # ------------------------------------ import logging import os -from typing import Optional, Any +from typing import Optional, Any, cast -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential from .chained import ChainedTokenCredential from .environment import EnvironmentCredential from .managed_identity import ManagedIdentityCredential @@ -82,10 +83,39 @@ class 
AzureApplicationCredential(ChainedTokenCredential): `message` attribute listing each authentication attempt and its error message. """ if self._successful_credential: - token = await self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + token = await cast(AsyncTokenCredential, self._successful_credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, **kwargs + ) _LOGGER.info( "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ ) return token return await super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a + `message` attribute listing each authentication attempt and its error message. 
+ """ + if self._successful_credential: + token_info = await cast(AsyncSupportsTokenInfo, self._successful_credential).get_token_info( + *scopes, options=options + ) + _LOGGER.info( + "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ + ) + return token_info + + return await cast(AsyncSupportsTokenInfo, super()).get_token_info(*scopes, options=options) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py index aa2ed7d6d87..c075cd30055 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py @@ -5,7 +5,7 @@ from typing import Optional, Any, cast from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from .._internal import AadClient, AsyncContextManager from .._internal.get_token_mixin import GetTokenMixin @@ -96,10 +96,35 @@ class AuthorizationCodeCredential(AsyncContextManager, GetTokenMixin): *scopes, claims=claims, tenant_id=tenant_id, client_secret=self._client_secret, **kwargs ) - async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + The first time this method is called, the credential will redeem its authorization code. 
On subsequent calls + the credential will return a cached access token or redeem a refresh token, if it acquired a refresh token upon + redeeming the authorization code. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. Any error response from Microsoft Entra ID is available as the error's + ``response`` attribute. + """ + return await super()._get_token_base( + *scopes, options=options, client_secret=self._client_secret, base_method_name="get_token_info" + ) + + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: if self._authorization_code: token = await self._client.obtain_token_by_authorization_code( scopes=scopes, code=self._authorization_code, redirect_uri=self._redirect_uri, **kwargs @@ -107,7 +132,7 @@ class AuthorizationCodeCredential(AsyncContextManager, GetTokenMixin): self._authorization_code = None # auth codes are single-use return token - token = cast(AccessToken, None) + token = cast(AccessTokenInfo, None) for refresh_token in self._client.get_cached_refresh_tokens(scopes): if "secret" in refresh_token: token = await self._client.obtain_token_by_refresh_token(scopes, refresh_token["secret"], **kwargs) diff --git 
a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azd_cli.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azd_cli.py index 19b1fda2c99..eafeb5affd4 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azd_cli.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azd_cli.py @@ -9,7 +9,7 @@ import sys from typing import Any, List, Optional from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from .._internal import AsyncContextManager from .._internal.decorators import log_get_token_async from ... import CredentialUnavailableError @@ -108,9 +108,46 @@ class AzureDeveloperCliCredential(AsyncContextManager): if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop): return _SyncAzureDeveloperCliCredential().get_token(*scopes, tenant_id=tenant_id, **kwargs) + options: TokenRequestOptions = {} + if tenant_id: + options["tenant_id"] = tenant_id + + token_info = await self._get_token_base(*scopes, options=options, **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + @log_get_token_async + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. Applications calling this method + directly must also handle token caching because this credential doesn't cache the tokens it acquires. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. 
+ :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke + the Azure Developer CLI. + :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked + the Azure Developer CLI but didn't receive an access token. + """ + # only ProactorEventLoop supports subprocesses on Windows (and it isn't the default loop on Python < 3.8) + if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop): + return _SyncAzureDeveloperCliCredential().get_token_info(*scopes, options=options) + return await self._get_token_base(*scopes, options=options) + + async def _get_token_base( + self, *scopes: str, options: Optional[TokenRequestOptions] = None, **kwargs: Any + ) -> AccessTokenInfo: if not scopes: raise ValueError("Missing scope in request. \n") + tenant_id = options.get("tenant_id") if options else None if tenant_id: validate_tenant_id(tenant_id) for scope in scopes: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py index dec0fcc690b..62f4f23e478 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py @@ -9,7 +9,7 @@ import sys from typing import Any, List, Optional from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from .._internal import AsyncContextManager from .._internal.decorators import log_get_token_async from ... 
import CredentialUnavailableError @@ -89,6 +89,42 @@ class AzureCliCredential(AsyncContextManager): if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop): return _SyncAzureCliCredential().get_token(*scopes, tenant_id=tenant_id, **kwargs) + options: TokenRequestOptions = {} + if tenant_id: + options["tenant_id"] = tenant_id + + token_info = await self._get_token_base(*scopes, options=options, **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + @log_get_token_async + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. Applications calling this method + directly must also handle token caching because this credential doesn't cache the tokens it acquires. + + :param str scopes: desired scopes for the access token. This credential allows only one scope per request. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke the Azure CLI. + :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked the Azure CLI but didn't + receive an access token. 
+ """ + # only ProactorEventLoop supports subprocesses on Windows (and it isn't the default loop on Python < 3.8) + if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop): + return _SyncAzureCliCredential().get_token_info(*scopes, options=options) + return await self._get_token_base(*scopes, options=options) + + async def _get_token_base( + self, *scopes: str, options: Optional[TokenRequestOptions] = None, **kwargs: Any + ) -> AccessTokenInfo: + tenant_id = options.get("tenant_id") if options else None if tenant_id: validate_tenant_id(tenant_id) for scope in scopes: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_pipelines.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_pipelines.py index 918cae86a92..ccaa635f1da 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_pipelines.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_pipelines.py @@ -5,7 +5,7 @@ from typing import Any, Optional from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.rest import HttpResponse from .client_assertion import ClientAssertionCredential @@ -103,6 +103,25 @@ class AzurePipelinesCredential(AsyncContextManager): *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs ) + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This method requires at least one scope. 
+ For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + validate_env_vars() + return await self._client_assertion_credential.get_token_info(*scopes, options=options) + def _get_oidc_token(self) -> str: request = build_oidc_request(self._service_connection_id, self._system_access_token) response = self._pipeline.run(request, retry_on_methods=[request.method]) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py index 6f7e0ffc5fa..f9117f2e66c 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py @@ -5,7 +5,7 @@ import asyncio import sys from typing import Any, cast, List, Optional -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from .._internal import AsyncContextManager from .._internal.decorators import log_get_token_async @@ -84,6 +84,42 @@ class AzurePowerShellCredential(AsyncContextManager): if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop): return _SyncCredential().get_token(*scopes, tenant_id=tenant_id, **kwargs) + options: TokenRequestOptions = {} + if tenant_id: + options["tenant_id"] = tenant_id + + token_info = await self._get_token_base(*scopes, options=options, **kwargs) + return 
AccessToken(token_info.token, token_info.expires_on)
+
+    @log_get_token_async
+    async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
+        """Request an access token for `scopes`.
+
+        This is an alternative to `get_token` to enable certain scenarios that require additional properties
+        on the token. This method is called automatically by Azure SDK clients. Applications calling this method
+        directly must also handle token caching because this credential doesn't cache the tokens it acquires.
+
+        :param str scopes: desired scopes for the access token. This credential allows only one scope per request.
+            For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc.
+        :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional.
+        :paramtype options: ~azure.core.credentials.TokenRequestOptions
+
+        :rtype: AccessTokenInfo
+        :return: An AccessTokenInfo instance containing information about the token.
+ + :raises ~azure.identity.CredentialUnavailableError: the credential was unable to invoke Azure PowerShell, or + no account is authenticated + :raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked Azure PowerShell but didn't + receive an access token + """ + if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop): + return _SyncCredential().get_token_info(*scopes, options=options) + return await self._get_token_base(*scopes, options=options) + + async def _get_token_base( + self, *scopes: str, options: Optional[TokenRequestOptions] = None, **kwargs: Any + ) -> AccessTokenInfo: + tenant_id = options.get("tenant_id") if options else None if tenant_id: validate_tenant_id(tenant_id) for scope in scopes: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py index 67a2fefbc3b..8d2ad9ae12f 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py @@ -4,7 +4,7 @@ # ------------------------------------ from typing import Optional, Any -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from .._internal import AadClient, AsyncContextManager from .._internal.get_token_mixin import GetTokenMixin from ..._credentials.certificate import get_client_credential @@ -70,8 +70,8 @@ class CertificateCredential(AsyncContextManager, GetTokenMixin): await self._client.__aexit__() - async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + async def _request_token(self, 
*scopes: str, **kwargs: Any) -> AccessTokenInfo: return await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py index e6cb50744fd..16ce23709dc 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py @@ -4,18 +4,16 @@ # ------------------------------------ import asyncio import logging -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Optional, cast from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider from .._internal import AsyncContextManager from ... import CredentialUnavailableError from ..._credentials.chained import _get_error_message from ..._internal import within_credential_chain -if TYPE_CHECKING: - from azure.core.credentials_async import AsyncTokenCredential - _LOGGER = logging.getLogger(__name__) @@ -38,11 +36,11 @@ class ChainedTokenCredential(AsyncContextManager): :caption: Create a ChainedTokenCredential. 
""" - def __init__(self, *credentials: "AsyncTokenCredential") -> None: + def __init__(self, *credentials: AsyncTokenProvider) -> None: if not credentials: raise ValueError("at least one credential is required") - self._successful_credential = None # type: Optional[AsyncTokenCredential] + self._successful_credential: Optional[AsyncTokenProvider] = None self.credentials = credentials async def close(self) -> None: @@ -51,7 +49,12 @@ class ChainedTokenCredential(AsyncContextManager): await asyncio.gather(*(credential.close() for credential in self.credentials)) async def get_token( - self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any + self, + *scopes: str, + claims: Optional[str] = None, + tenant_id: Optional[str] = None, + enable_cae: bool = False, + **kwargs: Any, ) -> AccessToken: """Asynchronously request a token from each credential, in order, returning the first token received. @@ -66,6 +69,8 @@ class ChainedTokenCredential(AsyncContextManager): :keyword str claims: additional claims required in the token, such as those returned in a resource provider's claims challenge following an authorization failure. :keyword str tenant_id: optional tenant to include in the token request. + :keyword bool enable_cae: indicates whether to enable Continuous Access Evaluation (CAE) for the requested + token. Defaults to False. :return: An access token with the desired scopes. :rtype: ~azure.core.credentials.AccessToken @@ -75,7 +80,21 @@ class ChainedTokenCredential(AsyncContextManager): history = [] for credential in self.credentials: try: - token = await credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + # Prioritize "get_token". Fall back to "get_token_info" if not available. 
+                if hasattr(credential, "get_token"):
+                    token = await cast(AsyncTokenCredential, credential).get_token(
+                        *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
+                    )
+                else:
+                    options: TokenRequestOptions = {}
+                    if claims:
+                        options["claims"] = claims
+                    if tenant_id:
+                        options["tenant_id"] = tenant_id
+                    options["enable_cae"] = enable_cae
+                    token_info = await cast(AsyncSupportsTokenInfo, credential).get_token_info(*scopes, options=options)
+                    token = AccessToken(token_info.token, token_info.expires_on)
+
                 _LOGGER.info("%s acquired a token from %s", self.__class__.__name__, credential.__class__.__name__)
                 self._successful_credential = credential
                 within_credential_chain.set(False)
@@ -94,7 +113,6 @@ class ChainedTokenCredential(AsyncContextManager):
                     exc_info=True,
                 )
                 break
-            within_credential_chain.set(False)

         attempts = _get_error_message(history)
         message = (
@@ -105,3 +123,69 @@
             "https://aka.ms/azsdk/python/identity/defaultazurecredential/troubleshoot."
         )
         raise ClientAuthenticationError(message=message)
+
+    async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
+        """Request a token from each chained credential, in order, returning the first token received.
+
+        If no credential provides a token, raises :class:`azure.core.exceptions.ClientAuthenticationError`
+        with an error message from each credential.
+
+        This is an alternative to `get_token` to enable certain scenarios that require additional properties
+        on the token. This method is called automatically by Azure SDK clients.
+
+        :param str scopes: desired scopes for the access token. This method requires at least one scope.
+            For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc.
+        :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional.
+ :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.core.exceptions.ClientAuthenticationError: no credential in the chain provided a token. + """ + within_credential_chain.set(True) + history = [] + options = options or {} + for credential in self.credentials: + try: + # Prioritize "get_token_info". Fall back to "get_token" if not available. + if hasattr(credential, "get_token_info"): + token_info = await cast(AsyncSupportsTokenInfo, credential).get_token_info(*scopes, options=options) + else: + if options.get("pop"): + raise CredentialUnavailableError( + "Proof of possession arguments are not supported for this credential." + ) + token = await cast(AsyncTokenCredential, credential).get_token(*scopes, **options) + token_info = AccessTokenInfo(token=token.token, expires_on=token.expires_on) + + _LOGGER.info("%s acquired a token from %s", self.__class__.__name__, credential.__class__.__name__) + self._successful_credential = credential + within_credential_chain.set(False) + return token_info + except CredentialUnavailableError as ex: + # credential didn't attempt authentication because it lacks required data or state -> continue + history.append((credential, ex.message)) + except Exception as ex: # pylint: disable=broad-except + # credential failed to authenticate, or something unexpectedly raised -> break + history.append((credential, str(ex))) + _LOGGER.debug( + '%s.get_token_info failed: %s raised unexpected error "%s"', + self.__class__.__name__, + credential.__class__.__name__, + ex, + exc_info=True, + ) + break + + within_credential_chain.set(False) + attempts = _get_error_message(history) + message = ( + self.__class__.__name__ + + " failed to retrieve a token from the included credentials." 
+ + attempts + + "\nTo mitigate this issue, please refer to the troubleshooting guidelines here at " + "https://aka.ms/azsdk/python/identity/defaultazurecredential/troubleshoot." + ) + _LOGGER.warning(message) + raise ClientAuthenticationError(message=message) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py index 64150cfdcd4..a316760455e 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py @@ -4,7 +4,7 @@ # ------------------------------------ from typing import Any, Callable, Optional -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from .._internal import AadClient, AsyncContextManager from .._internal.get_token_mixin import GetTokenMixin @@ -66,10 +66,10 @@ class ClientAssertionCredential(AsyncContextManager, GetTokenMixin): """Close the credential's transport session.""" await self._client.close() - async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: assertion = self._func() token = await self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs) return token diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py index f5495b17da6..50bbb3de931 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py +++ 
b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py @@ -4,7 +4,7 @@ # ------------------------------------ from typing import Optional, Any -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from .._internal import AadClient, AsyncContextManager from .._internal.get_token_mixin import GetTokenMixin from ..._internal import validate_tenant_id @@ -60,8 +60,8 @@ class ClientSecretCredential(AsyncContextManager, GetTokenMixin): await self._client.__aexit__() - async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: return await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py index bd8672a2c7c..fe1bc03b608 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py @@ -4,9 +4,10 @@ # ------------------------------------ import logging import os -from typing import List, Optional, TYPE_CHECKING, Any, cast +from typing import List, Optional, Any, cast -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncTokenCredential, AsyncSupportsTokenInfo from ..._constants import EnvironmentVariables from ..._internal import get_default_authority, normalize_authority, within_dac from .azure_cli import AzureCliCredential @@ -19,8 +20,6 @@ from 
.shared_cache import SharedTokenCacheCredential from .vscode import VisualStudioCodeCredential from .workload_identity import WorkloadIdentityCredential -if TYPE_CHECKING: - from azure.core.credentials_async import AsyncTokenCredential _LOGGER = logging.getLogger(__name__) @@ -134,7 +133,7 @@ class DefaultAzureCredential(ChainedTokenCredential): exclude_shared_token_cache_credential = kwargs.pop("exclude_shared_token_cache_credential", False) exclude_powershell_credential = kwargs.pop("exclude_powershell_credential", False) - credentials = [] # type: List[AsyncTokenCredential] + credentials: List[AsyncSupportsTokenInfo] = [] within_dac.set(True) if not exclude_environment_credential: credentials.append(EnvironmentCredential(authority=authority, _within_dac=True, **kwargs)) @@ -197,8 +196,46 @@ class DefaultAzureCredential(ChainedTokenCredential): `message` attribute listing each authentication attempt and its error message. """ if self._successful_credential: - return await self._successful_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + token = await cast(AsyncTokenCredential, self._successful_credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, **kwargs + ) + _LOGGER.info( + "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ + ) + return token + within_dac.set(True) token = await super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) within_dac.set(False) return token + + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Asynchronously request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. 
+ For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a + `message` attribute listing each authentication attempt and its error message. + """ + if self._successful_credential: + token_info = await cast(AsyncSupportsTokenInfo, self._successful_credential).get_token_info( + *scopes, options=options + ) + _LOGGER.info( + "%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__ + ) + return token_info + + within_dac.set(True) + token_info = await cast(AsyncSupportsTokenInfo, super()).get_token_info(*scopes, options=options) + within_dac.set(False) + return token_info diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py index 146c91edd4b..bb981406d2a 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py @@ -4,9 +4,10 @@ # ------------------------------------ import logging import os -from typing import Optional, Union, Any +from typing import Optional, Union, Any, cast -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo from .._internal.decorators import log_get_token_async from ... 
import CredentialUnavailableError from ..._constants import EnvironmentVariables @@ -124,3 +125,29 @@ class EnvironmentCredential(AsyncContextManager): ) raise CredentialUnavailableError(message=message) return await self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + + @log_get_token_async + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + + :raises ~azure.identity.CredentialUnavailableError: environment variable configuration is incomplete. + """ + if not self._credential: + message = ( + "EnvironmentCredential authentication unavailable. Environment variables are not fully configured.\n" + "Visit https://aka.ms/azsdk/python/identity/environmentcredential/troubleshoot to troubleshoot " + "this issue." 
+ ) + raise CredentialUnavailableError(message=message) + return await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py index 0f667be030e..f9286c9f88f 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py @@ -6,7 +6,7 @@ import os from typing import Optional, Any from azure.core.exceptions import ClientAuthenticationError, HttpResponseError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from ... import CredentialUnavailableError from ..._constants import EnvironmentVariables from .._internal import AsyncContextManager @@ -34,10 +34,10 @@ class ImdsCredential(AsyncContextManager, GetTokenMixin): async def close(self) -> None: await self._client.close() - async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: return self._client.get_cached_token(*scopes) - async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: # pylint:disable=unused-argument if within_credential_chain.get() and not self._endpoint_available: # If within a chain (e.g. 
DefaultAzureCredential), we do a quick check to see if the IMDS endpoint @@ -56,7 +56,7 @@ class ImdsCredential(AsyncContextManager, GetTokenMixin): raise CredentialUnavailableError(message=error_message) from ex try: - token = await self._client.request_token(*scopes, headers={"Metadata": "true"}) + token_info = await self._client.request_token(*scopes, headers={"Metadata": "true"}) except CredentialUnavailableError: # Response is not json, skip the IMDS credential raise @@ -82,4 +82,4 @@ class ImdsCredential(AsyncContextManager, GetTokenMixin): # if anything else was raised, assume the endpoint is unavailable error_message = "ManagedIdentityCredential authentication unavailable, no response from the IMDS endpoint." raise CredentialUnavailableError(error_message) from ex - return token + return token_info diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py index 0362886712e..22ab3f4a269 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py @@ -4,17 +4,16 @@ # ------------------------------------ import logging import os -from typing import TYPE_CHECKING, Optional, Any, Mapping +from typing import Optional, Any, Mapping, cast -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncTokenCredential, AsyncSupportsTokenInfo from .._internal import AsyncContextManager from .._internal.decorators import log_get_token_async from ... 
import CredentialUnavailableError from ..._constants import EnvironmentVariables from ..._credentials.managed_identity import validate_identity_config -if TYPE_CHECKING: - from azure.core.credentials_async import AsyncTokenCredential _LOGGER = logging.getLogger(__name__) @@ -48,7 +47,7 @@ class ManagedIdentityCredential(AsyncContextManager): self, *, client_id: Optional[str] = None, identity_config: Optional[Mapping[str, str]] = None, **kwargs: Any ) -> None: validate_identity_config(client_id, identity_config) - self._credential: Optional[AsyncTokenCredential] = None + self._credential: Optional[AsyncSupportsTokenInfo] = None exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False) if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT): @@ -141,4 +140,31 @@ class ManagedIdentityCredential(AsyncContextManager): "Visit https://aka.ms/azsdk/python/identity/managedidentitycredential/troubleshoot to " "troubleshoot this issue." ) - return await self._credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) + return await cast(AsyncTokenCredential, self._credential).get_token( + *scopes, claims=claims, tenant_id=tenant_id, **kwargs + ) + + @log_get_token_async + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This credential allows only one scope per request. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. 
+ :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.identity.CredentialUnavailableError: managed identity isn't available in the hosting environment. + """ + if not self._credential: + raise CredentialUnavailableError( + message="No managed identity endpoint found. \n" + "The Target Azure platform could not be determined from environment variables. \n" + "Visit https://aka.ms/azsdk/python/identity/managedidentitycredential/troubleshoot to " + "troubleshoot this issue." + ) + return await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(*scopes, options=options) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py index d807db72c94..102db030f63 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py @@ -6,7 +6,7 @@ import logging from typing import Optional, Union, Any, Dict, Callable from azure.core.exceptions import ClientAuthenticationError -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from .._internal import AadClient, AsyncContextManager from .._internal.get_token_mixin import GetTokenMixin from ..._credentials.certificate import get_client_credential @@ -111,10 +111,10 @@ class OnBehalfOfCredential(AsyncContextManager, GetTokenMixin): async def close(self) -> None: await self._client.close() - async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + async def 
_request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: # Note we assume the cache has tokens for one user only. That's okay because each instance of this class is # locked to a single user (assertion). This assumption will become unsafe if this class allows applications # to change an instance's assertion. diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py index 395ccb009a7..6d5f38d497f 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py @@ -2,8 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import Any, Optional -from azure.core.credentials import AccessToken +from typing import Any, Optional, cast +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from ..._internal.aad_client import AadClientBase from ... import CredentialUnavailableError from ..._constants import DEVELOPER_SIGN_ON_CLIENT_ID @@ -48,7 +48,7 @@ class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncContextManager): claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: """Get an access token for `scopes` from the shared cache. @@ -73,13 +73,58 @@ class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncContextManager): attribute gives a reason. Any error response from Microsoft Entra ID is available as the error's ``response`` attribute. 
""" + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + + token_info = await self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + @log_get_token_async + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Get an access token for `scopes` from the shared cache. + + If no access token is cached, attempt to acquire one using a cached refresh token. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scope for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user + information + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. Any error response from Microsoft Entra ID is available as the error's + ``response`` attribute. 
+ """ + return await self._get_token_base(*scopes, options=options, base_method_name="get_token_info") + + async def _get_token_base( + self, + *scopes: str, + options: Optional[TokenRequestOptions] = None, + base_method_name: str = "get_token_info", + **kwargs: Any, + ) -> AccessTokenInfo: if not scopes: - raise ValueError("'get_token' requires at least one scope") + raise ValueError(f"'{base_method_name}' requires at least one scope") if not self._client_initialized: self._initialize_client() - is_cae = enable_cae + options = options or {} + claims = options.get("claims") + tenant_id = options.get("tenant_id") + is_cae = options.get("enable_cae", False) + token_cache = self._cae_cache if is_cae else self._cache # Try to load the cache if it is None. @@ -98,8 +143,8 @@ class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncContextManager): # try each refresh token, returning the first access token acquired for refresh_token in self._get_refresh_tokens(account, is_cae=is_cae): - token = await self._client.obtain_token_by_refresh_token( - scopes, refresh_token, claims=claims, tenant_id=tenant_id, **kwargs + token = await cast(AadClient, self._client).obtain_token_by_refresh_token( + scopes, refresh_token, claims=claims, tenant_id=tenant_id, enable_cae=is_cae, **kwargs ) return token diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py index 35eda5eb547..9451cc45f01 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py @@ -4,7 +4,7 @@ # ------------------------------------ from typing import cast, Optional, Any -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from azure.core.exceptions import ClientAuthenticationError from ..._exceptions import CredentialUnavailableError from 
.._internal import AsyncContextManager @@ -83,11 +83,42 @@ class VisualStudioCodeCredential(_VSCodeCredentialBase, AsyncContextManager, Get raise CredentialUnavailableError(message=ex.message) from ex return await super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) - async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes` as the user currently signed in to Visual Studio Code. + + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises ~azure.identity.CredentialUnavailableError: the credential cannot retrieve user details from Visual + Studio Code. + """ + if self._unavailable_reason: + error_message = ( + self._unavailable_reason + "\n" + "Visit https://aka.ms/azsdk/python/identity/vscodecredential/troubleshoot" + " to troubleshoot this issue." 
+ ) + raise CredentialUnavailableError(message=error_message) + if within_dac.get(): + try: + token = await super().get_token_info(*scopes, options=options) + return token + except ClientAuthenticationError as ex: + raise CredentialUnavailableError(message=ex.message) from ex + return await super().get_token_info(*scopes, options=options) + + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: self._client = cast(AadClient, self._client) return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: refresh_token = self._get_refresh_token() self._client = cast(AadClient, self._client) return await self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py index 1a6bb03cc19..7b99f85ac91 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -5,7 +5,7 @@ import time from typing import Iterable, Optional, Union, Dict, Any -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from azure.core.pipeline import AsyncPipeline from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy from azure.core.pipeline.transport import HttpRequest @@ -32,7 +32,7 @@ class AadClient(AadClientBase): async def obtain_token_by_authorization_code( self, scopes: Iterable[str], code: str, redirect_uri: str, client_secret: Optional[str] = None, **kwargs - ) -> AccessToken: + ) -> AccessTokenInfo: request = self._get_auth_code_request( scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs ) @@ -40,19 +40,21 @@ class 
AadClient(AadClientBase): async def obtain_token_by_client_certificate( self, scopes: Iterable[str], certificate: AadClientCertificate, **kwargs - ) -> AccessToken: + ) -> AccessTokenInfo: request = self._get_client_certificate_request(scopes, certificate, **kwargs) return await self._run_pipeline(request, stream=False, **kwargs) - async def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs) -> AccessToken: + async def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs) -> AccessTokenInfo: request = self._get_client_secret_request(scopes, secret, **kwargs) return await self._run_pipeline(request, **kwargs) - async def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs) -> AccessToken: + async def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs) -> AccessTokenInfo: request = self._get_jwt_assertion_request(scopes, assertion, **kwargs) return await self._run_pipeline(request, stream=False, **kwargs) - async def obtain_token_by_refresh_token(self, scopes: Iterable[str], refresh_token: str, **kwargs) -> AccessToken: + async def obtain_token_by_refresh_token( + self, scopes: Iterable[str], refresh_token: str, **kwargs + ) -> AccessTokenInfo: request = self._get_refresh_token_request(scopes, refresh_token, **kwargs) return await self._run_pipeline(request, **kwargs) @@ -62,7 +64,7 @@ class AadClient(AadClientBase): client_credential: Union[str, AadClientCertificate, Dict[str, Any]], refresh_token: str, **kwargs - ) -> AccessToken: + ) -> AccessTokenInfo: request = self._get_refresh_token_on_behalf_of_request( scopes, client_credential=client_credential, refresh_token=refresh_token, **kwargs ) @@ -74,7 +76,7 @@ class AadClient(AadClientBase): client_credential: Union[str, AadClientCertificate, Dict[str, Any]], user_assertion: str, **kwargs - ) -> AccessToken: + ) -> AccessTokenInfo: request = self._get_on_behalf_of_request( scopes=scopes, 
client_credential=client_credential, user_assertion=user_assertion, **kwargs ) @@ -83,7 +85,7 @@ class AadClient(AadClientBase): def _build_pipeline(self, **kwargs) -> AsyncPipeline: return build_async_pipeline(**kwargs) - async def _run_pipeline(self, request: HttpRequest, **kwargs) -> AccessToken: + async def _run_pipeline(self, request: HttpRequest, **kwargs) -> AccessTokenInfo: # remove tenant_id and claims kwarg that could have been passed from credential's get_token method # tenant_id is already part of `request` at this point kwargs.pop("tenant_id", None) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py index d9ceaf4e861..61ad01eaab7 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py @@ -7,7 +7,7 @@ import logging import time from typing import Any, Optional -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from ..._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY from ..._internal import within_credential_chain @@ -22,7 +22,7 @@ class GetTokenMixin(abc.ABC): super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore @abc.abstractmethod - async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessTokenInfo]: """Attempt to acquire an access token from a cache or by redeeming a refresh token. :param str scopes: desired scopes for the access token. This method requires at least one scope. @@ -30,11 +30,11 @@ class GetTokenMixin(abc.ABC): https://learn.microsoft.com/entra/identity-platform/scopes-oidc. :return: An access token with the desired scopes if successful; otherwise, None. 
- :rtype: ~azure.core.credentials.AccessToken or None + :rtype: ~azure.core.credentials.AccessTokenInfo or None """ @abc.abstractmethod - async def _request_token(self, *scopes: str, **kwargs) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo: """Request an access token from the STS. :param str scopes: desired scopes for the access token. This method requires at least one scope. @@ -42,11 +42,13 @@ class GetTokenMixin(abc.ABC): https://learn.microsoft.com/entra/identity-platform/scopes-oidc. :return: An access token with the desired scopes. - :rtype: ~azure.core.credentials.AccessToken + :rtype: ~azure.core.credentials.AccessTokenInfo """ - def _should_refresh(self, token: AccessToken) -> bool: + def _should_refresh(self, token: AccessTokenInfo) -> bool: now = int(time.time()) + if token.refresh_on is not None and now >= token.refresh_on: + return True if token.expires_on - now > DEFAULT_REFRESH_OFFSET: return False if now - self._last_request_time < DEFAULT_TOKEN_REFRESH_RETRY_DELAY: @@ -59,7 +61,7 @@ class GetTokenMixin(abc.ABC): claims: Optional[str] = None, tenant_id: Optional[str] = None, enable_cae: bool = False, - **kwargs: Any + **kwargs: Any, ) -> AccessToken: """Request an access token for `scopes`. @@ -81,8 +83,50 @@ class GetTokenMixin(abc.ABC): :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` attribute gives a reason. """ + options: TokenRequestOptions = {} + if claims: + options["claims"] = claims + if tenant_id: + options["tenant_id"] = tenant_id + options["enable_cae"] = enable_cae + + token_info = await self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs) + return AccessToken(token_info.token, token_info.expires_on) + + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + """Request an access token for `scopes`. 
+ + This is an alternative to `get_token` to enable certain scenarios that require additional properties + on the token. This method is called automatically by Azure SDK clients. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. + :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional. + :paramtype options: ~azure.core.credentials.TokenRequestOptions + + :rtype: AccessTokenInfo + :return: An AccessTokenInfo instance containing information about the token. + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks + required data, state, or platform support + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + return await self._get_token_base(*scopes, options=options, base_method_name="get_token_info") + + async def _get_token_base( + self, + *scopes: str, + options: Optional[TokenRequestOptions] = None, + base_method_name: str = "get_token_info", + **kwargs: Any, + ) -> AccessTokenInfo: if not scopes: - raise ValueError('"get_token" requires at least one scope') + raise ValueError(f'"{base_method_name}" requires at least one scope') + + options = options or {} + claims = options.get("claims") + tenant_id = options.get("tenant_id") + enable_cae = options.get("enable_cae", False) try: token = await self._acquire_token_silently( @@ -103,16 +147,18 @@ class GetTokenMixin(abc.ABC): pass _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.INFO, - "%s.get_token succeeded", + "%s.%s succeeded", self.__class__.__name__, + base_method_name, ) return token except Exception as ex: _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.WARNING, - "%s.get_token failed: %s", + "%s.%s failed: %s", 
self.__class__.__name__, + base_method_name, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), ) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py index 1bebc28fb5c..636fbbf9b2f 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py @@ -6,7 +6,7 @@ import abc from types import TracebackType from typing import Any, cast, Optional, TypeVar, Type -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from . import AsyncContextManager from .get_token_mixin import GetTokenMixin from .managed_identity_client import AsyncManagedIdentityClient @@ -54,10 +54,15 @@ class AsyncManagedIdentityBase(AsyncContextManager, GetTokenMixin): raise CredentialUnavailableError(message=self.get_unavailable_message()) return await super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs) - async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]: + async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo: + if not self._client: + raise CredentialUnavailableError(message=self.get_unavailable_message()) + return await super().get_token_info(*scopes, options=options) + + async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessTokenInfo]: # casting because mypy can't determine that these methods are called # only by get_token, which raises when self._client is None return cast(AsyncManagedIdentityClient, self._client).get_cached_token(*scopes) - async def _request_token(self, *scopes: str, **kwargs) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo: return await cast(AsyncManagedIdentityClient, 
self._client).request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_client.py index 05503316d85..340bb5a11fa 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_client.py @@ -5,7 +5,7 @@ import time from typing import TypeVar -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo from azure.core.pipeline import AsyncPipeline from .._internal import AsyncContextManager from ..._internal import _scopes_to_resource @@ -24,7 +24,7 @@ class AsyncManagedIdentityClient(AsyncContextManager, ManagedIdentityClientBase) async def close(self) -> None: await self._pipeline.__aexit__() - async def request_token(self, *scopes: str, **kwargs) -> AccessToken: + async def request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo: # pylint:disable=invalid-overridden-method resource = _scopes_to_resource(*scopes) request = self._request_factory(resource, self._identity_config) diff --git a/sdk/identity/azure-identity/setup.py b/sdk/identity/azure-identity/setup.py index e16249897f8..b56ebc2abaa 100644 --- a/sdk/identity/azure-identity/setup.py +++ b/sdk/identity/azure-identity/setup.py @@ -59,9 +59,9 @@ setup( ), python_requires=">=3.8", install_requires=[ - "azure-core>=1.23.0", + "azure-core>=1.31.0", "cryptography>=2.5", - "msal>=1.29.0", + "msal>=1.30.0", "msal-extensions>=1.2.0", "typing-extensions>=4.0.0", ], diff --git a/sdk/identity/azure-identity/tests/helpers.py b/sdk/identity/azure-identity/tests/helpers.py index 23378598869..1a654313396 100644 --- a/sdk/identity/azure-identity/tests/helpers.py +++ b/sdk/identity/azure-identity/tests/helpers.py @@ -6,15 +6,15 @@ import base64 import json import time from urllib.parse import urlparse +from unittest 
import mock -try: - from unittest import mock -except ImportError: # python < 3.3 - import mock # type: ignore +from azure.core.credentials import AccessToken, AccessTokenInfo FAKE_CLIENT_ID = "fake-client-id" INVALID_CHARACTERS = "|\\`;{&' " +ACCESS_TOKEN_CLASSES = (AccessToken, AccessTokenInfo) +GET_TOKEN_METHODS = ("get_token", "get_token_info") def build_id_token( diff --git a/sdk/identity/azure-identity/tests/helpers_async.py b/sdk/identity/azure-identity/tests/helpers_async.py index 308b7db75b1..2a66e167c32 100644 --- a/sdk/identity/azure-identity/tests/helpers_async.py +++ b/sdk/identity/azure-identity/tests/helpers_async.py @@ -10,15 +10,6 @@ from unittest import mock from helpers import validating_transport -def await_test(fn): - @functools.wraps(fn) - def wrapper(*args, **kwargs): - loop = asyncio.get_event_loop() - loop.run_until_complete(fn(*args, **kwargs)) - - return wrapper - - def get_completed_future(result=None): future = asyncio.Future() future.set_result(result) diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index bd0dc721eb3..b9f8c374d34 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. 
# ------------------------------------ import functools +from unittest.mock import Mock, patch from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError from azure.identity._constants import EnvironmentVariables @@ -15,11 +16,6 @@ from urllib.parse import urlparse from helpers import build_aad_response, mock_response from test_certificate_credential import PEM_CERT_PATH -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore - BASE_CLASS_METHODS = [ ("_get_auth_code_request", ("code", "redirect_uri")), diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index f618669fba3..56a2f1486a8 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -180,7 +180,7 @@ async def test_request_url(authority): await client.obtain_token_by_refresh_token("scope", "refresh token") # obtain_token_by_refresh_token is client_secret safe - client.obtain_token_by_refresh_token("scope", "refresh token", client_secret="secret") + await client.obtain_token_by_refresh_token("scope", "refresh token", client_secret="secret") # authority can be configured via environment variable with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): diff --git a/sdk/identity/azure-identity/tests/test_app_service_async.py b/sdk/identity/azure-identity/tests/test_app_service_async.py index 5d32df81c74..e79964c39a7 100644 --- a/sdk/identity/azure-identity/tests/test_app_service_async.py +++ b/sdk/identity/azure-identity/tests/test_app_service_async.py @@ -11,7 +11,6 @@ import pytest from devtools_testutils import is_live from devtools_testutils.aio import recorded_by_proxy_async -from helpers_async import await_test from recorded_test_case import RecordedTestCase from test_app_service import PLAYBACK_URL @@ -33,7 
+32,7 @@ class TestAppServiceAsync(RecordedTestCase): self.patch = mock.patch.dict(os.environ, env, clear=True) @pytest.mark.manual - @await_test + @pytest.mark.asyncio @recorded_by_proxy_async async def test_system_assigned(self): self.load_settings() @@ -44,7 +43,7 @@ class TestAppServiceAsync(RecordedTestCase): assert isinstance(token.expires_on, int) @pytest.mark.manual - @await_test + @pytest.mark.asyncio @recorded_by_proxy_async async def test_system_assigned_tenant_id(self): with self.patch: @@ -55,7 +54,7 @@ class TestAppServiceAsync(RecordedTestCase): @pytest.mark.manual @pytest.mark.usefixtures("user_assigned_identity_client_id") - @await_test + @pytest.mark.asyncio @recorded_by_proxy_async async def test_user_assigned(self): self.load_settings() @@ -67,7 +66,7 @@ class TestAppServiceAsync(RecordedTestCase): @pytest.mark.manual @pytest.mark.usefixtures("user_assigned_identity_client_id") - @await_test + @pytest.mark.asyncio @recorded_by_proxy_async async def test_user_assigned_tenant_id(self): with self.patch: diff --git a/sdk/identity/azure-identity/tests/test_application_credential.py b/sdk/identity/azure-identity/tests/test_application_credential.py index c7a8a84f9ff..cd761d17aef 100644 --- a/sdk/identity/azure-identity/tests/test_application_credential.py +++ b/sdk/identity/azure-identity/tests/test_application_credential.py @@ -3,23 +3,21 @@ # Licensed under the MIT License. 
# ------------------------------------ import os +from unittest.mock import Mock, patch +from urllib.parse import urlparse -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.identity import CredentialUnavailableError from azure.identity._credentials.application import AzureApplicationCredential from azure.identity._constants import EnvironmentVariables import pytest from urllib.parse import urlparse -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore - -from helpers import build_aad_response, get_discovery_response, mock_response +from helpers import build_aad_response, get_discovery_response, mock_response, GET_TOKEN_METHODS -def test_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_get_token(get_token_method): expected_token = "***" def send(request, **kwargs): @@ -35,34 +33,42 @@ def test_get_token(): with patch.dict("os.environ", {var: "..." 
for var in EnvironmentVariables.CLIENT_SECRET_VARS}, clear=True): credential = AzureApplicationCredential(transport=Mock(send=send)) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token -def test_iterates_only_once(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_iterates_only_once(get_token_method): """When a credential succeeds, AzureApplicationCredential should use that credential thereafter""" - expected_token = AccessToken("***", 42) + access_token = "***" unavailable_credential = Mock( - spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message="...")) + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="...")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="...")), + ) + successful_credential = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=AccessToken(access_token, 42)), + get_token_info=Mock(return_value=AccessTokenInfo(access_token, 42)), ) - successful_credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=expected_token)) credential = AzureApplicationCredential() - credential.credentials = [ + credential.credentials = ( unavailable_credential, successful_credential, Mock( - spec_set=["get_token"], + spec_set=["get_token", "get_token_info"], get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), + get_token_info=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), ), - ] + ) for n in range(3): - token = credential.get_token("scope") - assert token.token == expected_token.token - assert unavailable_credential.get_token.call_count == 1 - assert successful_credential.get_token.call_count == n + 1 + token = getattr(credential, get_token_method)("scope") + assert token.token == access_token + assert 
getattr(unavailable_credential, get_token_method).call_count == 1 + assert getattr(successful_credential, get_token_method).call_count == n + 1 @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) diff --git a/sdk/identity/azure-identity/tests/test_application_credential_async.py b/sdk/identity/azure-identity/tests/test_application_credential_async.py index a317c840308..58a4a64edab 100644 --- a/sdk/identity/azure-identity/tests/test_application_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_application_credential_async.py @@ -5,44 +5,50 @@ import os from unittest.mock import Mock, patch -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.identity import CredentialUnavailableError from azure.identity.aio._credentials.application import AzureApplicationCredential from azure.identity._constants import EnvironmentVariables import pytest from urllib.parse import urlparse -from helpers import build_aad_response, mock_response +from helpers import build_aad_response, mock_response, GET_TOKEN_METHODS from helpers_async import get_completed_future @pytest.mark.asyncio -async def test_iterates_only_once(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_iterates_only_once(get_token_method): """When a credential succeeds, AzureApplicationCredential should use that credential thereafter""" - expected_token = AccessToken("***", 42) + access_token = "***" unavailable_credential = Mock( - spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message="...")) + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="...")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="...")), ) successful_credential = Mock( - spec_set=["get_token"], get_token=Mock(return_value=get_completed_future(expected_token)) + spec_set=["get_token", "get_token_info"], + 
get_token=Mock(return_value=get_completed_future(AccessToken(access_token, 42))), + get_token_info=Mock(return_value=get_completed_future(AccessTokenInfo(access_token, 42))), ) credential = AzureApplicationCredential() - credential.credentials = [ + credential.credentials = ( unavailable_credential, successful_credential, Mock( - spec_set=["get_token"], + spec_set=["get_token", "get_token_info"], get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), + get_token_info=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), ), - ] + ) for n in range(3): - token = await credential.get_token("scope") - assert token.token == expected_token.token - assert unavailable_credential.get_token.call_count == 1 - assert successful_credential.get_token.call_count == n + 1 + token = await getattr(credential, get_token_method)("scope") + assert token.token == access_token + assert getattr(unavailable_credential, get_token_method).call_count == 1 + assert getattr(successful_credential, get_token_method).call_count == n + 1 @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) @@ -83,7 +89,8 @@ def test_authority(authority): @pytest.mark.asyncio -async def test_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_get_token(get_token_method): expected_token = "***" async def send(request, **kwargs): @@ -95,7 +102,7 @@ async def test_get_token(): with patch.dict("os.environ", {var: "..." 
for var in EnvironmentVariables.CLIENT_SECRET_VARS}, clear=True): credential = AzureApplicationCredential(transport=Mock(send=send)) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_auth_code.py b/sdk/identity/azure-identity/tests/test_auth_code.py index a4ee062e84d..0d2f5f4d881 100644 --- a/sdk/identity/azure-identity/tests/test_auth_code.py +++ b/sdk/identity/azure-identity/tests/test_auth_code.py @@ -2,32 +2,30 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from azure.core.exceptions import ClientAuthenticationError +from unittest.mock import Mock, patch +from urllib.parse import urlparse + from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity import AuthorizationCodeCredential from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT import msal import pytest -from urllib.parse import urlparse -from helpers import build_aad_response, mock_response, Request, validating_transport - -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore +from helpers import build_aad_response, mock_response, Request, validating_transport, GET_TOKEN_METHODS -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = AuthorizationCodeCredential("tenant-id", "client-id", "auth-code", "http://localhost") with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): 
policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) def send(*_, **kwargs): @@ -40,12 +38,13 @@ def test_policies_configurable(): "tenant-id", "client-id", "auth-code", "http://localhost", policies=[policy], transport=Mock(send=send) ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert policy.on_request.called -def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_user_agent(get_token_method): transport = validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -55,10 +54,11 @@ def test_user_agent(): "tenant-id", "client-id", "auth-code", "http://localhost", transport=transport ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_tenant_id(get_token_method): transport = validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -73,10 +73,14 @@ def test_tenant_id(): additionally_allowed_tenants=["*"], ) - credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) -def test_auth_code_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_auth_code_credential(get_token_method): client_id = "client id" secret = "fake-client-secret" tenant_id = "tenant" @@ -126,24 +130,25 @@ def test_auth_code_credential(): ) # first call should redeem the auth code - token = credential.get_token(expected_scope) + token = getattr(credential, get_token_method)(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 1 # no 
auth code -> credential should return cached token - token = credential.get_token(expected_scope) + token = getattr(credential, get_token_method)(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 1 # no auth code, no cached token -> credential should redeem refresh token cached_access_token = list(cache.search(cache.CredentialType.ACCESS_TOKEN))[0] cache.remove_at(cached_access_token) - token = credential.get_token(expected_scope) + token = getattr(credential, get_token_method)(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 2 -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -167,21 +172,28 @@ def test_multitenant_authentication(): transport=Mock(send=send), additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def 
test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -203,15 +215,24 @@ def test_multitenant_authentication_not_allowed(): additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token - token = credential.get_token("scope", tenant_id=expected_tenant) + kwargs = {"tenant_id": expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token - token = credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token * 2 with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_auth_code_async.py b/sdk/identity/azure-identity/tests/test_auth_code_async.py index 2d8c984851e..04f86d510ca 100644 --- a/sdk/identity/azure-identity/tests/test_auth_code_async.py +++ b/sdk/identity/azure-identity/tests/test_auth_code_async.py @@ -13,21 +13,23 @@ from azure.identity.aio import AuthorizationCodeCredential import msal import pytest -from helpers import build_aad_response, mock_response, Request +from helpers import build_aad_response, mock_response, Request, GET_TOKEN_METHODS from helpers_async import async_validating_transport, AsyncMockTransport pytestmark = 
pytest.mark.asyncio -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = AuthorizationCodeCredential("tenant-id", "client-id", "auth-code", "http://localhost") with pytest.raises(ValueError): - await credential.get_token() + await getattr(credential, get_token_method)() -async def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) async def send(*_, **kwargs): @@ -40,7 +42,7 @@ async def test_policies_configurable(): "tenant-id", "client-id", "auth-code", "http://localhost", policies=[policy], transport=Mock(send=send) ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert policy.on_request.called @@ -69,7 +71,8 @@ async def test_context_manager(): assert transport.__aexit__.call_count == 1 -async def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_user_agent(get_token_method): transport = async_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -79,10 +82,11 @@ async def test_user_agent(): "tenant-id", "client-id", "auth-code", "http://localhost", transport=transport ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_tenant_id(get_token_method): transport = async_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -97,10 +101,14 @@ async def 
test_tenant_id(): additionally_allowed_tenants=["*"], ) - await credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(credential, get_token_method)("scope", **kwargs) -async def test_auth_code_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_auth_code_credential(get_token_method): client_id = "client id" secret = "fake-client-secret" tenant_id = "tenant" @@ -150,24 +158,25 @@ async def test_auth_code_credential(): ) # first call should redeem the auth code - token = await credential.get_token(expected_scope) + token = await getattr(credential, get_token_method)(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 1 # no auth code -> credential should return cached token - token = await credential.get_token(expected_scope) + token = await getattr(credential, get_token_method)(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 1 # no auth code, no cached token -> credential should redeem refresh token cached_access_token = list(cache.search(cache.CredentialType.ACCESS_TOKEN))[0] cache.remove_at(cached_access_token) - token = await credential.get_token(expected_scope) + token = await getattr(credential, get_token_method)(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 2 -async def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -191,21 +200,28 @@ async def test_multitenant_authentication(): transport=Mock(send=send), additionally_allowed_tenants=["*"], ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert 
token.token == first_token - token = await credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token -async def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -227,15 +243,24 @@ async def test_multitenant_authentication_not_allowed(): additionally_allowed_tenants=["*"], ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token - token = await credential.get_token("scope", tenant_id=expected_tenant) + kwargs = {"tenant_id": expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token - token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token * 2 with patch.dict("os.environ", 
{EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_authority.py b/sdk/identity/azure-identity/tests/test_authority.py index d6a31b627f0..725c481ff5c 100644 --- a/sdk/identity/azure-identity/tests/test_authority.py +++ b/sdk/identity/azure-identity/tests/test_authority.py @@ -2,10 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore +from unittest.mock import patch from azure.identity._constants import EnvironmentVariables, KnownAuthorities from azure.identity._internal import get_default_authority, normalize_authority diff --git a/sdk/identity/azure-identity/tests/test_azd_cli_credential.py b/sdk/identity/azure-identity/tests/test_azd_cli_credential.py index 0bceaaba3c0..44af829dfec 100644 --- a/sdk/identity/azure-identity/tests/test_azd_cli_credential.py +++ b/sdk/identity/azure-identity/tests/test_azd_cli_credential.py @@ -14,7 +14,7 @@ from azure.core.exceptions import ClientAuthenticationError import subprocess import pytest -from helpers import mock, INVALID_CHARACTERS +from helpers import mock, INVALID_CHARACTERS, GET_TOKEN_METHODS CHECK_OUTPUT = AzureDeveloperCliCredential.__module__ + ".subprocess.check_output" @@ -35,14 +35,16 @@ def raise_called_process_error(return_code, output="", cmd="...", stderr=""): return mock.Mock(side_effect=error) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential 
should raise ValueError when get_token is called with no scopes""" with pytest.raises(ValueError): - AzureDeveloperCliCredential().get_token() + getattr(AzureDeveloperCliCredential(), get_token_method)() -def test_invalid_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_invalid_tenant_id(get_token_method): """Invalid tenant IDs should raise ValueErrors.""" for c in INVALID_CHARACTERS: @@ -50,21 +52,26 @@ def test_invalid_tenant_id(): AzureDeveloperCliCredential(tenant_id="tenant" + c) with pytest.raises(ValueError): - AzureDeveloperCliCredential().get_token("scope", tenant_id="tenant" + c) + kwargs = {"tenant_id": "tenant" + c} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(AzureDeveloperCliCredential(), get_token_method)("scope", **kwargs) -def test_invalid_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_invalid_scopes(get_token_method): """Scopes with invalid characters should raise ValueErrors.""" for c in INVALID_CHARACTERS: with pytest.raises(ValueError): - AzureDeveloperCliCredential().get_token("scope" + c) + getattr(AzureDeveloperCliCredential(), get_token_method)("scope" + c) with pytest.raises(ValueError): - AzureDeveloperCliCredential().get_token("scope", "scope2", "scope" + c) + getattr(AzureDeveloperCliCredential(), get_token_method)("scope", "scope2", "scope" + c) -def test_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_get_token(get_token_method): """The credential should parse the CLI's output to an token""" access_token = "access token" @@ -81,40 +88,44 @@ def test_get_token(): with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=successful_output)): - token = AzureDeveloperCliCredential().get_token("scope") + token = getattr(AzureDeveloperCliCredential(), get_token_method)("scope") assert token.token == access_token assert 
type(token.expires_on) == int assert token.expires_on == expected_expires_on -def test_cli_not_installed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cli_not_installed(get_token_method): """The credential should raise CredentialUnavailableError when the CLI isn't installed""" with mock.patch("shutil.which", return_value=None): with pytest.raises(CredentialUnavailableError, match=CLI_NOT_FOUND): - AzureDeveloperCliCredential().get_token("scope") + getattr(AzureDeveloperCliCredential(), get_token_method)("scope") -def test_cannot_execute_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cannot_execute_shell(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess doesn't start""" with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, mock.Mock(side_effect=OSError())): with pytest.raises(CredentialUnavailableError): - AzureDeveloperCliCredential().get_token("scope") + getattr(AzureDeveloperCliCredential(), get_token_method)("scope") -def test_not_logged_in(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_not_logged_in(get_token_method): """When the CLI isn't logged in, the credential should raise CredentialUnavailableError""" stderr = "ERROR: not logged in, run `azd auth login` to login" with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(1, stderr=stderr)): with pytest.raises(CredentialUnavailableError, match=NOT_LOGGED_IN): - AzureDeveloperCliCredential().get_token("scope") + getattr(AzureDeveloperCliCredential(), get_token_method)("scope") -def test_aadsts_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_aadsts_error(get_token_method): """When there is an AADSTS error, the credential should raise an error containing the CLI's output even if the error also contains the 'not logged in' string.""" @@ -122,46 +133,50 
@@ def test_aadsts_error(): with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(1, stderr=stderr)): with pytest.raises(ClientAuthenticationError, match=stderr): - AzureDeveloperCliCredential().get_token("scope") + getattr(AzureDeveloperCliCredential(), get_token_method)("scope") -def test_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_unexpected_error(get_token_method): """When the CLI returns an unexpected error, the credential should raise an error containing the CLI's output""" stderr = "something went wrong" with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(42, stderr=stderr)): with pytest.raises(ClientAuthenticationError, match=stderr): - AzureDeveloperCliCredential().get_token("scope") + getattr(AzureDeveloperCliCredential(), get_token_method)("scope") @pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -def test_parsing_error_does_not_expose_token(output): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_parsing_error_does_not_expose_token(output, get_token_method): """Errors during CLI output parsing shouldn't expose access tokens in that output""" with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=output)): with pytest.raises(ClientAuthenticationError) as ex: - AzureDeveloperCliCredential().get_token("scope") + getattr(AzureDeveloperCliCredential(), get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) @pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -def test_subprocess_error_does_not_expose_token(output): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_subprocess_error_does_not_expose_token(output, get_token_method): """Errors from the subprocess shouldn't expose access tokens in CLI output""" with 
mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(1, output=output)): with pytest.raises(ClientAuthenticationError) as ex: - AzureDeveloperCliCredential().get_token("scope") + getattr(AzureDeveloperCliCredential(), get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) -def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_timeout(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess times out""" from subprocess import TimeoutExpired @@ -169,7 +184,7 @@ def test_timeout(): with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, mock.Mock(side_effect=TimeoutExpired("", 42))) as check_output_mock: with pytest.raises(CredentialUnavailableError): - AzureDeveloperCliCredential(process_timeout=42).get_token("scope") + getattr(AzureDeveloperCliCredential(process_timeout=42), get_token_method)("scope") # Ensure custom timeout is passed to subprocess _, kwargs = check_output_mock.call_args @@ -177,7 +192,8 @@ def test_timeout(): assert kwargs["timeout"] == 42 -def test_multitenant_authentication_class(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_class(get_token_method): default_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -199,17 +215,18 @@ def test_multitenant_authentication_class(): with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, fake_check_output): - token = AzureDeveloperCliCredential().get_token("scope") + token = getattr(AzureDeveloperCliCredential(), get_token_method)("scope") assert token.token == first_token - token = AzureDeveloperCliCredential(tenant_id=default_tenant).get_token("scope") + token = getattr(AzureDeveloperCliCredential(tenant_id=default_tenant), get_token_method)("scope") assert token.token == 
first_token - token = AzureDeveloperCliCredential(tenant_id=second_tenant).get_token("scope") + token = getattr(AzureDeveloperCliCredential(tenant_id=second_tenant), get_token_method)("scope") assert token.token == second_token -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): default_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -232,21 +249,28 @@ def test_multitenant_authentication(): credential = AzureDeveloperCliCredential() with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, fake_check_output): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -266,9 +290,12 @@ def test_multitenant_authentication_not_allowed(): credential = AzureDeveloperCliCredential() with mock.patch("shutil.which", return_value="azd"): with mock.patch(CHECK_OUTPUT, 
fake_check_output): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_azd_cli_credential_async.py b/sdk/identity/azure-identity/tests/test_azd_cli_credential_async.py index 267a88ab85d..95f7b2b2b14 100644 --- a/sdk/identity/azure-identity/tests/test_azd_cli_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_azd_cli_credential_async.py @@ -16,7 +16,7 @@ from azure.identity._credentials.azd_cli import CLI_NOT_FOUND, NOT_LOGGED_IN from azure.core.exceptions import ClientAuthenticationError import pytest -from helpers import INVALID_CHARACTERS +from helpers import INVALID_CHARACTERS, GET_TOKEN_METHODS from helpers_async import get_completed_future from test_azd_cli_credential import TEST_ERROR_OUTPUTS @@ -33,14 +33,16 @@ def mock_exec(stdout, stderr="", return_code=0): return mock.Mock(return_value=get_completed_future(process)) -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" with pytest.raises(ValueError): - await AzureDeveloperCliCredential().get_token() + await getattr(AzureDeveloperCliCredential(), get_token_method)() -async def test_invalid_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_invalid_tenant_id(get_token_method): """Invalid tenant IDs should raise ValueErrors.""" for c in 
INVALID_CHARACTERS: @@ -48,18 +50,22 @@ async def test_invalid_tenant_id(): AzureDeveloperCliCredential(tenant_id="tenant" + c) with pytest.raises(ValueError): - await AzureDeveloperCliCredential().get_token("scope", tenant_id="tenant" + c) + kwargs = {"tenant_id": "tenant" + c} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(AzureDeveloperCliCredential(), get_token_method)("scope", **kwargs) -async def test_invalid_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_invalid_scopes(get_token_method): """Scopes with invalid characters should raise ValueErrors.""" for c in INVALID_CHARACTERS: with pytest.raises(ValueError): - await AzureDeveloperCliCredential().get_token("scope" + c) + await getattr(AzureDeveloperCliCredential(), get_token_method)("scope" + c) with pytest.raises(ValueError): - await AzureDeveloperCliCredential().get_token("scope", "scope2", "scope" + c) + await getattr(AzureDeveloperCliCredential(), get_token_method)("scope", "scope2", "scope" + c) async def test_close(): @@ -76,21 +82,25 @@ async def test_context_manager(): @pytest.mark.skipif(not sys.platform.startswith("win"), reason="tests Windows-specific behavior") -async def test_windows_fallback(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_windows_fallback(get_token_method): """The credential should fall back to the sync implementation when not using ProactorEventLoop on Windows""" sync_get_token = mock.Mock() with mock.patch("azure.identity.aio._credentials.azd_cli._SyncAzureDeveloperCliCredential") as fallback: - fallback.return_value = mock.Mock(spec_set=["get_token"], get_token=sync_get_token) + fallback.return_value = mock.Mock( + spec_set=["get_token", "get_token_info"], get_token=sync_get_token, get_token_info=sync_get_token + ) with mock.patch(AzureDeveloperCliCredential.__module__ + ".asyncio.get_event_loop"): # asyncio.get_event_loop now returns Mock, i.e. 
never ProactorEventLoop credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert sync_get_token.call_count == 1 -async def test_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_get_token(get_token_method): """The credential should parse the CLI's output to an AccessToken""" access_token = "access token" @@ -108,33 +118,36 @@ async def test_get_token(): with mock.patch("shutil.which", return_value="azd"): with mock.patch(SUBPROCESS_EXEC, mock_exec(successful_output)): credential = AzureDeveloperCliCredential() - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == access_token assert type(token.expires_on) == int assert token.expires_on == expected_expires_on -async def test_cli_not_installed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cli_not_installed(get_token_method): """The credential should raise CredentialUnavailableError when the CLI isn't installed""" with mock.patch("shutil.which", return_value=None): with pytest.raises(CredentialUnavailableError, match=CLI_NOT_FOUND): credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_cannot_execute_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cannot_execute_shell(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess doesn't start""" with mock.patch("shutil.which", return_value="azd"): with mock.patch(SUBPROCESS_EXEC, mock.Mock(side_effect=OSError())): with pytest.raises(CredentialUnavailableError): credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_not_logged_in(): 
+@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_not_logged_in(get_token_method): """When the CLI isn't logged in, the credential should raise CredentialUnavailableError""" stderr = "ERROR: not logged in, run `azd auth login` to login" @@ -142,10 +155,11 @@ async def test_not_logged_in(): with mock.patch(SUBPROCESS_EXEC, mock_exec("", stderr, return_code=1)): with pytest.raises(CredentialUnavailableError, match=NOT_LOGGED_IN): credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_aadsts_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_aadsts_error(get_token_method): """When there is an AADSTS error, the credential should raise an error containing the CLI's output even if the error also contains the 'not logged in' string.""" @@ -154,10 +168,11 @@ async def test_aadsts_error(): with mock.patch(SUBPROCESS_EXEC, mock_exec("", stderr, return_code=1)): with pytest.raises(ClientAuthenticationError, match=stderr): credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_unexpected_error(get_token_method): """When the CLI returns an unexpected error, the credential should raise an error containing the CLI's output""" stderr = "something went wrong" @@ -165,50 +180,54 @@ async def test_unexpected_error(): with mock.patch(SUBPROCESS_EXEC, mock_exec("", stderr, return_code=42)): with pytest.raises(ClientAuthenticationError, match=stderr): credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -async def test_parsing_error_does_not_expose_token(output): 
+@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_parsing_error_does_not_expose_token(output, get_token_method): """Errors during CLI output parsing shouldn't expose access tokens in that output""" with mock.patch("shutil.which", return_value="azd"): with mock.patch(SUBPROCESS_EXEC, mock_exec(output)): with pytest.raises(ClientAuthenticationError) as ex: credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) @pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -async def test_subprocess_error_does_not_expose_token(output): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_subprocess_error_does_not_expose_token(output, get_token_method): """Errors from the subprocess shouldn't expose access tokens in CLI output""" with mock.patch("shutil.which", return_value="azd"): with mock.patch(SUBPROCESS_EXEC, mock_exec(output, return_code=1)): with pytest.raises(ClientAuthenticationError) as ex: credential = AzureDeveloperCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) -async def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_timeout(get_token_method): """The credential should kill the subprocess after a timeout""" proc = mock.Mock(communicate=mock.Mock(side_effect=asyncio.TimeoutError), returncode=None) with mock.patch("shutil.which", return_value="azd"): with mock.patch(SUBPROCESS_EXEC, mock.Mock(return_value=get_completed_future(proc))): with pytest.raises(CredentialUnavailableError): - await AzureDeveloperCliCredential().get_token("scope") + await getattr(AzureDeveloperCliCredential(), get_token_method)("scope") assert proc.communicate.call_count == 1 
-async def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(get_token_method): default_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -232,21 +251,28 @@ async def test_multitenant_authentication(): credential = AzureDeveloperCliCredential() with mock.patch("shutil.which", return_value="azd"): with mock.patch(SUBPROCESS_EXEC, fake_exec): - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token -async def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -267,9 +293,12 @@ async def test_multitenant_authentication_not_allowed(): credential = AzureDeveloperCliCredential() with mock.patch("shutil.which", return_value="azd"): with mock.patch(SUBPROCESS_EXEC, fake_exec): - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == 
expected_token with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_azure_application.py b/sdk/identity/azure-identity/tests/test_azure_application.py index 0945a1a1823..8772264ade7 100644 --- a/sdk/identity/azure-identity/tests/test_azure_application.py +++ b/sdk/identity/azure-identity/tests/test_azure_application.py @@ -2,10 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -try: - from unittest.mock import patch -except ImportError: - from mock import patch # type: ignore +from unittest.mock import patch from azure.identity._credentials.application import AzureApplicationCredential from azure.identity._constants import EnvironmentVariables diff --git a/sdk/identity/azure-identity/tests/test_azure_arc.py b/sdk/identity/azure-identity/tests/test_azure_arc.py index 2f216db452f..8234eecde7c 100644 --- a/sdk/identity/azure-identity/tests/test_azure_arc.py +++ b/sdk/identity/azure-identity/tests/test_azure_arc.py @@ -11,8 +11,11 @@ import msal from azure.core.exceptions import ClientAuthenticationError from azure.identity._credentials.azure_arc import AzureArcCredential +from helpers import GET_TOKEN_METHODS -def test_msal_managed_identity_error(): + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_msal_managed_identity_error(get_token_method): scopes = ["scope1"] def mock_request_token(*args, **kwargs): @@ -22,4 +25,4 @@ def test_msal_managed_identity_error(): cred._msal_client.acquire_token_for_client = mock_request_token with 
pytest.raises(ClientAuthenticationError): - cred.get_token(*scopes) + getattr(cred, get_token_method)(*scopes) diff --git a/sdk/identity/azure-identity/tests/test_azure_pipelines_credential.py b/sdk/identity/azure-identity/tests/test_azure_pipelines_credential.py index 459f2e2251f..1f9e27dab2d 100644 --- a/sdk/identity/azure-identity/tests/test_azure_pipelines_credential.py +++ b/sdk/identity/azure-identity/tests/test_azure_pipelines_credential.py @@ -16,6 +16,8 @@ from azure.identity import ( ) from azure.identity._credentials.azure_pipelines import SYSTEM_OIDCREQUESTURI, OIDC_API_VERSION, build_oidc_request +from helpers import GET_TOKEN_METHODS + def test_azure_pipelines_credential_initialize(): system_access_token = "token" @@ -76,7 +78,8 @@ def test_build_oidc_request(): assert request.headers["Authorization"] == f"Bearer {access_token}" -def test_azure_pipelines_credential_missing_system_env_var(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_azure_pipelines_credential_missing_system_env_var(get_token_method): credential = AzurePipelinesCredential( system_access_token="token", client_id="client-id", @@ -86,11 +89,12 @@ def test_azure_pipelines_credential_missing_system_env_var(): with patch.dict("os.environ", {}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert f"Missing value for the {SYSTEM_OIDCREQUESTURI} environment variable" in str(ex.value) -def test_azure_pipelines_credential_in_chain(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_azure_pipelines_credential_in_chain(get_token_method): mock_credential = MagicMock() with patch.dict("os.environ", {}, clear=True): @@ -103,12 +107,13 @@ def test_azure_pipelines_credential_in_chain(): ), mock_credential, ) - chain_credential.get_token("scope") - assert mock_credential.get_token.called + getattr(chain_credential, get_token_method)("scope") + 
assert getattr(mock_credential, get_token_method).called @pytest.mark.live_test_only("Requires Azure Pipelines environment with configured service connection") -def test_azure_pipelines_credential_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_azure_pipelines_credential_authentication(get_token_method): system_access_token = os.environ.get("SYSTEM_ACCESSTOKEN", "") service_connection_id = os.environ.get("AZURE_SERVICE_CONNECTION_ID", "") tenant_id = os.environ.get("AZURE_SERVICE_CONNECTION_TENANT_ID", "") @@ -126,6 +131,6 @@ def test_azure_pipelines_credential_authentication(): service_connection_id=service_connection_id, ) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token assert isinstance(token.expires_on, int) diff --git a/sdk/identity/azure-identity/tests/test_azure_pipelines_credential_async.py b/sdk/identity/azure-identity/tests/test_azure_pipelines_credential_async.py index cc3cf313a88..471bc4f3218 100644 --- a/sdk/identity/azure-identity/tests/test_azure_pipelines_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_azure_pipelines_credential_async.py @@ -11,6 +11,8 @@ from azure.identity import CredentialUnavailableError from azure.identity._credentials.azure_pipelines import SYSTEM_OIDCREQUESTURI from azure.identity.aio import AzurePipelinesCredential, ChainedTokenCredential, ClientAssertionCredential +from helpers import GET_TOKEN_METHODS + def test_azure_pipelines_credential_initialize(): system_access_token = "token" @@ -57,7 +59,8 @@ async def test_azure_pipelines_credential_context_manager(): @pytest.mark.asyncio -async def test_azure_pipelines_credential_missing_system_env_var(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_pipelines_credential_missing_system_env_var(get_token_method): credential = AzurePipelinesCredential( system_access_token="token", client_id="client-id", @@ -67,12 
+70,13 @@ async def test_azure_pipelines_credential_missing_system_env_var(): with patch.dict("os.environ", {}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert f"Missing value for the {SYSTEM_OIDCREQUESTURI} environment variable" in str(ex.value) @pytest.mark.asyncio -async def test_azure_pipelines_credential_in_chain(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_pipelines_credential_in_chain(get_token_method): mock_credential = AsyncMock() with patch.dict("os.environ", {}, clear=True): @@ -85,13 +89,14 @@ async def test_azure_pipelines_credential_in_chain(): ), mock_credential, ) - await chain_credential.get_token("scope") - assert mock_credential.get_token.called + await getattr(chain_credential, get_token_method)("scope") + assert getattr(mock_credential, get_token_method).called @pytest.mark.asyncio @pytest.mark.live_test_only("Requires Azure Pipelines environment with configured service connection") -async def test_azure_pipelines_credential_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_pipelines_credential_authentication(get_token_method): system_access_token = os.environ.get("SYSTEM_ACCESSTOKEN", "") service_connection_id = os.environ.get("AZURE_SERVICE_CONNECTION_ID", "") tenant_id = os.environ.get("AZURE_SERVICE_CONNECTION_TENANT_ID", "") @@ -109,6 +114,6 @@ async def test_azure_pipelines_credential_authentication(): service_connection_id=service_connection_id, ) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token assert isinstance(token.expires_on, int) diff --git a/sdk/identity/azure-identity/tests/test_bearer_token_provider.py b/sdk/identity/azure-identity/tests/test_bearer_token_provider.py index f20ce7ac1d8..89cd5248645 100644 --- 
a/sdk/identity/azure-identity/tests/test_bearer_token_provider.py +++ b/sdk/identity/azure-identity/tests/test_bearer_token_provider.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.identity import get_bearer_token_provider @@ -14,7 +14,17 @@ class MockCredential: return AccessToken("mock_token", 42) +class MockCredentialTokenInfo: + def get_token_info(self, *scopes, **kwargs): + assert len(scopes) == 1 + assert scopes[0] == "scope" + return AccessTokenInfo("mock_token_2", 42) + + def test_get_bearer_token_provider(): func = get_bearer_token_provider(MockCredential(), "scope") assert func() == "mock_token" + + func = get_bearer_token_provider(MockCredentialTokenInfo(), "scope") # type: ignore + assert func() == "mock_token_2" diff --git a/sdk/identity/azure-identity/tests/test_bearer_token_provider_async.py b/sdk/identity/azure-identity/tests/test_bearer_token_provider_async.py index 35a8db46457..9fa6f67d575 100644 --- a/sdk/identity/azure-identity/tests/test_bearer_token_provider_async.py +++ b/sdk/identity/azure-identity/tests/test_bearer_token_provider_async.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. 
# ------------------------------------ -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.identity.aio import get_bearer_token_provider import pytest @@ -16,8 +16,18 @@ class MockCredential: return AccessToken("mock_token", 42) +class MockCredentialTokenInfo: + async def get_token_info(self, *scopes, **kwargs): + assert len(scopes) == 1 + assert scopes[0] == "scope" + return AccessTokenInfo("mock_token_2", 42) + + @pytest.mark.asyncio async def test_get_bearer_token_provider(): - func = get_bearer_token_provider(MockCredential(), "scope") + func = get_bearer_token_provider(MockCredential(), "scope") # type: ignore assert await func() == "mock_token" + + func = get_bearer_token_provider(MockCredentialTokenInfo(), "scope") # type: ignore + assert await func() == "mock_token_2" diff --git a/sdk/identity/azure-identity/tests/test_browser_credential.py b/sdk/identity/azure-identity/tests/test_browser_credential.py index 6cef0366c09..760af0a2a8d 100644 --- a/sdk/identity/azure-identity/tests/test_browser_credential.py +++ b/sdk/identity/azure-identity/tests/test_browser_credential.py @@ -2,7 +2,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
# ------------------------------------ -import platform import random import socket import threading @@ -18,13 +17,11 @@ import pytest from unittest.mock import ANY, Mock, patch from helpers import ( - build_aad_response, - build_id_token, get_discovery_response, - id_token_claims, mock_response, Request, validating_transport, + GET_TOKEN_METHODS, ) @@ -32,7 +29,8 @@ WEBBROWSER_OPEN = InteractiveBrowserCredential.__module__ + ".webbrowser.open" @pytest.mark.manual -def test_browser_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_browser_credential(get_token_method): transport = Mock(wraps=RequestsTransport()) credential = InteractiveBrowserCredential(transport=transport) scope = "https://management.azure.com/.default" # N.B. this is valid only in Public Cloud @@ -45,15 +43,15 @@ def test_browser_credential(): # credential should have a cached access token for the scope used in authenticate with patch(WEBBROWSER_OPEN, Mock(side_effect=Exception("credential should authenticate silently"))): - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token credential = InteractiveBrowserCredential(transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token with patch(WEBBROWSER_OPEN, Mock(side_effect=Exception("credential should authenticate silently"))): - second_token = credential.get_token(scope) + second_token = getattr(credential, get_token_method)(scope) assert second_token.token == token.token # every request should have the correct User-Agent @@ -76,14 +74,16 @@ def test_tenant_id_validation(): InteractiveBrowserCredential(tenant_id=tenant) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise when get_token is called with no scopes""" with pytest.raises(ValueError): - 
InteractiveBrowserCredential().get_token() + getattr(InteractiveBrowserCredential(), get_token_method)() -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): # the policy raises an exception so this test can run without authenticating i.e. opening a browser expected_message = "test_policies_configurable" policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock(side_effect=Exception(expected_message))) @@ -91,13 +91,14 @@ def test_policies_configurable(): credential = InteractiveBrowserCredential(policies=[policy]) with pytest.raises(ClientAuthenticationError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert expected_message in ex.value.message assert policy.on_request.called -def test_disable_automatic_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_disable_automatic_authentication(get_token_method): """When configured for strict silent auth, the credential should raise when silent auth fails""" transport = Mock(send=Mock(side_effect=Exception("no request should be sent"))) @@ -107,10 +108,11 @@ def test_disable_automatic_authentication(): with patch(WEBBROWSER_OPEN, Mock(side_effect=Exception("credential shouldn't try interactive authentication"))): with pytest.raises(AuthenticationRequiredError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_timeout(get_token_method): """get_token should raise ClientAuthenticationError when the server times out without receiving a redirect""" timeout = 0.01 @@ -133,11 +135,12 @@ def test_timeout(): with patch(WEBBROWSER_OPEN, lambda _: True): with pytest.raises(ClientAuthenticationError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert "timed out" in 
ex.value.message.lower() -def test_redirect_server(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_redirect_server(get_token_method): # binding a random port prevents races when running the test in parallel server = None hostname = "127.0.0.1" @@ -167,7 +170,8 @@ def test_redirect_server(): assert server.query_params[expected_param] == expected_value -def test_no_browser(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_browser(get_token_method): """The credential should raise CredentialUnavailableError when it can't open a browser""" transport = validating_transport(requests=[Request()] * 2, responses=[get_discovery_response()] * 2) @@ -176,10 +180,11 @@ def test_no_browser(): ) with patch(InteractiveBrowserCredential.__module__ + "._open_browser", lambda _: False): with pytest.raises(CredentialUnavailableError, match=r".*browser.*"): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_redirect_uri(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_redirect_uri(get_token_method): """The credential should configure the redirect server to use a given redirect_uri""" expected_hostname = "localhost" @@ -192,7 +197,7 @@ def test_redirect_uri(): client_credential="client_credential", ) with pytest.raises(ClientAuthenticationError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert expected_message in ex.value.message server.assert_called_once_with(expected_hostname, expected_port, timeout=ANY) @@ -206,17 +211,19 @@ def test_invalid_redirect_uri(redirect_uri): InteractiveBrowserCredential(redirect_uri=redirect_uri, client_credential="client_credential") -def test_cannot_bind_port(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cannot_bind_port(get_token_method): """get_token should raise CredentialUnavailableError when the redirect listener can't bind a port""" credential = 
InteractiveBrowserCredential( _server_class=Mock(side_effect=socket.error), client_credential="client_credential" ) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_cannot_bind_redirect_uri(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cannot_bind_redirect_uri(get_token_method): """When a user specifies a redirect URI, the credential shouldn't attempt to bind another""" server = Mock(side_effect=socket.error) @@ -225,6 +232,6 @@ def test_cannot_bind_redirect_uri(): ) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") server.assert_called_once_with("localhost", 42, timeout=ANY) diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential.py b/sdk/identity/azure-identity/tests/test_certificate_credential.py index a9352b85a57..a00443db404 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential.py @@ -4,8 +4,8 @@ # ------------------------------------ import json import os +from unittest.mock import Mock, patch -from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import CertificateCredential, TokenCachePersistenceOptions from azure.identity._enums import RegionalAuthority @@ -17,7 +17,6 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding from msal import TokenCache -from msal_extensions import PersistedTokenCache import msal import pytest from urllib.parse import urlparse @@ -29,16 +28,11 @@ from helpers import ( get_discovery_response, urlsafeb64_decode, mock_response, - msal_validating_transport, new_msal_validating_transport, Request, + 
GET_TOKEN_METHODS, ) -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore - PEM_CERT_PATH = os.path.join(os.path.dirname(__file__), "certificate.pem") PEM_CERT_WITH_PASSWORD_PATH = os.path.join(os.path.dirname(__file__), "certificate-with-password.pem") PFX_CERT_PATH = os.path.join(os.path.dirname(__file__), "certificate.pfx") @@ -77,15 +71,17 @@ def test_tenant_id_validation(): CertificateCredential(tenant, "client-id", PEM_CERT_PATH) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = CertificateCredential("tenant-id", "client-id", PEM_CERT_PATH) with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) transport = new_msal_validating_transport( @@ -96,12 +92,13 @@ def test_policies_configurable(): "tenant-id", "client-id", PEM_CERT_PATH, policies=[ContentDecodePolicy(), policy], transport=transport ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert policy.on_request.called -def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_user_agent(get_token_method): transport = new_msal_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -109,10 +106,11 @@ def test_user_agent(): credential = CertificateCredential("tenant-id", "client-id", PEM_CERT_PATH, transport=transport) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_tenant_id(): 
+@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_tenant_id(get_token_method): transport = new_msal_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -122,11 +120,15 @@ def test_tenant_id(): "tenant-id", "client-id", PEM_CERT_PATH, transport=transport, additionally_allowed_tenants=["*"] ) - credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) -def test_authority(authority): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authority(authority, get_token_method): """the credential should accept an authority, with or without scheme, as an argument or environment variable""" tenant_id = "expected-tenant" @@ -141,7 +143,7 @@ def test_authority(authority): credential = CertificateCredential(tenant_id, "client-id", PEM_CERT_PATH, authority=authority) with patch("msal.ConfidentialClientApplication", mock_ctor): # must call get_token because the credential constructs the MSAL application lazily - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_ctor.call_count == 1 _, kwargs = mock_ctor.call_args @@ -152,14 +154,15 @@ def test_authority(authority): with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): credential = CertificateCredential(tenant_id, "client-id", PEM_CERT_PATH, authority=authority) with patch("msal.ConfidentialClientApplication", mock_ctor): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_ctor.call_count == 1 _, kwargs = mock_ctor.call_args assert kwargs["authority"] == expected_authority -def test_regional_authority(): 
+@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_regional_authority(get_token_method): """the credential should configure MSAL with a regional authority specified via kwarg or environment variable""" mock_confidential_client = Mock( @@ -173,7 +176,7 @@ def test_regional_authority(): with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: region.value}, clear=True): credential = CertificateCredential("tenant", "client-id", PEM_CERT_PATH) with patch("msal.ConfidentialClientApplication", mock_confidential_client): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_confidential_client.call_count == 1 _, kwargs = mock_confidential_client.call_args @@ -200,7 +203,8 @@ def test_requires_certificate(): @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) @pytest.mark.parametrize("send_certificate_chain", (True, False)) -def test_request_body(cert_path, cert_password, send_certificate_chain): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_request_body(cert_path, cert_password, send_certificate_chain, get_token_method): access_token = "***" authority = "authority.com" client_id = "client-id" @@ -228,7 +232,7 @@ def test_request_body(cert_path, cert_password, send_certificate_chain): authority=authority, send_certificate_chain=send_certificate_chain, ) - token = cred.get_token(expected_scope) + token = getattr(cred, get_token_method)(expected_scope) assert token.token == access_token # credential should also accept the certificate as bytes @@ -244,7 +248,7 @@ def test_request_body(cert_path, cert_password, send_certificate_chain): authority=authority, send_certificate_chain=send_certificate_chain, ) - token = cred.get_token(expected_scope) + token = getattr(cred, get_token_method)(expected_scope) assert token.token == access_token @@ -294,7 +298,8 @@ def validate_jwt(request, client_id, cert_bytes, cert_password, expect_x5c=False 
@pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -def test_token_cache_persistent(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_cache_persistent(cert_path, cert_password, get_token_method): """the credential should use a persistent cache if cache_persistence_options are configured""" access_token = "foo token" @@ -324,19 +329,23 @@ def test_token_cache_persistent(cert_path, cert_password): assert credential._cache is None assert credential._cae_cache is None - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == access_token assert load_persistent_cache.call_count == 1 assert credential._cache is not None assert credential._cae_cache is None - token = credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert load_persistent_cache.call_count == 2 assert credential._cae_cache is not None @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -def test_token_cache_memory(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_cache_memory(cert_path, cert_password, get_token_method): """The credential should default to in-memory cache if no persistence options are provided.""" access_token = "foo token" @@ -356,19 +365,23 @@ def test_token_cache_memory(cert_path, cert_password): ) assert credential._cache is None - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == access_token assert isinstance(credential._cache, TokenCache) assert credential._cae_cache is None assert not load_persistent_cache.called - token = credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = 
{"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert isinstance(credential._cae_cache, TokenCache) assert not load_persistent_cache.called @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -def test_persistent_cache_multiple_clients(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_persistent_cache_multiple_clients(cert_path, cert_password, get_token_method): """the credential shouldn't use tokens issued to other service principals""" access_token_a = "token a" @@ -403,13 +416,13 @@ def test_persistent_cache_multiple_clients(cert_path, cert_password): # A caches a token scope = "scope" - token_a = credential_a.get_token(scope) + token_a = getattr(credential_a, get_token_method)(scope) assert mock_cache_loader.call_count == 1, "credential should use the persistent cache" assert token_a.token == access_token_a assert transport_a.send.call_count == 2 # one MSAL discovery request, one token request # B should get a different token for the same scope - token_b = credential_b.get_token(scope) + token_b = getattr(credential_b, get_token_method)(scope) assert mock_cache_loader.call_count == 2, "credential should load the persistent cache" assert token_b.token == access_token_b assert transport_b.send.call_count == 2 @@ -427,7 +440,8 @@ def test_certificate_arguments(): @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -def test_multitenant_authentication(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(cert_path, cert_password, get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -454,22 +468,29 @@ def test_multitenant_authentication(cert_path, cert_password): transport=Mock(send=send), additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert 
token.token == first_token - token = credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -def test_multitenant_authentication_backcompat(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_backcompat(cert_path, cert_password, get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -494,14 +515,20 @@ def test_multitenant_authentication_backcompat(cert_path, cert_password): additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token + kwargs = {"tenant_id": expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} # explicitly specifying the configured tenant is okay - token = credential.get_token("scope", tenant_id=expected_tenant) + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token - token = credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert 
token.token == expected_token @@ -524,7 +551,8 @@ def test_client_capabilities(): assert kwargs["client_capabilities"] == ["CP1"] -def test_claims_challenge(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_claims_challenge(get_token_method): """get_token should pass any claims challenge to MSAL token acquisition APIs""" msal_acquire_token_result = dict( @@ -540,7 +568,10 @@ def test_claims_challenge(): msal_app.acquire_token_silent_with_error.return_value = None msal_app.acquire_token_for_client.return_value = msal_acquire_token_result - credential.get_token("scope", claims=expected_claims) + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_silent_with_error.call_count == 1 args, kwargs = msal_app.acquire_token_silent_with_error.call_args diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py index 79786bb0216..a55d593e873 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py @@ -5,7 +5,6 @@ from unittest.mock import Mock, patch from urllib.parse import urlparse -from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import TokenCachePersistenceOptions from azure.identity._constants import EnvironmentVariables @@ -15,7 +14,7 @@ from azure.identity.aio import CertificateCredential from msal import TokenCache import pytest -from helpers import build_aad_response, mock_response, Request +from helpers import build_aad_response, mock_response, Request, GET_TOKEN_METHODS from helpers_async import async_validating_transport, AsyncMockTransport from test_certificate_credential import 
ALL_CERTS, EC_CERT_PATH, PEM_CERT_PATH, validate_jwt @@ -42,12 +41,13 @@ def test_tenant_id_validation(): @pytest.mark.asyncio -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = CertificateCredential("tenant-id", "client-id", PEM_CERT_PATH) with pytest.raises(ValueError): - await credential.get_token() + await getattr(credential, get_token_method)() @pytest.mark.asyncio @@ -73,7 +73,8 @@ async def test_context_manager(): @pytest.mark.asyncio -async def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) async def send(*_, **kwargs): @@ -86,13 +87,14 @@ async def test_policies_configurable(): "tenant-id", "client-id", PEM_CERT_PATH, policies=[ContentDecodePolicy(), policy], transport=Mock(send=send) ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert policy.on_request.called @pytest.mark.asyncio -async def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_user_agent(get_token_method): transport = async_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -100,11 +102,12 @@ async def test_user_agent(): credential = CertificateCredential("tenant-id", "client-id", PEM_CERT_PATH, transport=transport) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.asyncio -async def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_tenant_id(get_token_method): transport = async_validating_transport( 
requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -113,14 +116,17 @@ async def test_tenant_id(): credential = CertificateCredential( "tenant-id", "client-id", PEM_CERT_PATH, transport=transport, additionally_allowed_tenants=["*"] ) - - await credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(credential, get_token_method)("scope", **kwargs) @pytest.mark.asyncio @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -async def test_request_url(cert_path, cert_password, authority): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_request_url(cert_path, cert_password, authority, get_token_method): """the credential should accept an authority, with or without scheme, as an argument or environment variable""" tenant_id = "expected-tenant" @@ -138,7 +144,7 @@ async def test_request_url(cert_path, cert_password, authority): cred = CertificateCredential( tenant_id, "client-id", cert_path, password=cert_password, transport=Mock(send=mock_send), authority=authority ) - token = await cred.get_token("scope") + token = await getattr(cred, get_token_method)("scope") assert token.token == access_token # authority can be configured via environment variable @@ -146,7 +152,7 @@ async def test_request_url(cert_path, cert_password, authority): credential = CertificateCredential( tenant_id, "client-id", cert_path, password=cert_password, transport=Mock(send=mock_send) ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert token.token == access_token @@ -167,7 +173,8 @@ def test_requires_certificate(): @pytest.mark.asyncio @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -async def 
test_request_body(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_request_body(cert_path, cert_password, get_token_method): access_token = "***" authority = "authority.com" client_id = "client-id" @@ -186,7 +193,7 @@ async def test_request_body(cert_path, cert_password): cred = CertificateCredential( tenant_id, client_id, cert_path, password=cert_password, transport=Mock(send=mock_send), authority=authority ) - token = await cred.get_token(expected_scope) + token = await getattr(cred, get_token_method)(expected_scope) assert token.token == access_token # credential should also accept the certificate as bytes @@ -201,13 +208,14 @@ async def test_request_body(cert_path, cert_password): transport=Mock(send=mock_send), authority=authority, ) - token = await cred.get_token(expected_scope) + token = await getattr(cred, get_token_method)(expected_scope) assert token.token == access_token @pytest.mark.asyncio @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -async def test_token_cache_memory(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_cache_memory(cert_path, cert_password, get_token_method): """the credential should optionally use a persistent cache, and default to an in memory cache""" access_token = "token" @@ -227,13 +235,16 @@ async def test_token_cache_memory(cert_path, cert_password): assert not mock_token_cache.called assert not load_persistent_cache.called - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert mock_token_cache.call_count == 1 assert load_persistent_cache.call_count == 0 assert credential._client._cache is not None assert credential._client._cae_cache is None - await credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(credential, 
get_token_method)("scope", **kwargs) assert mock_token_cache.call_count == 2 assert load_persistent_cache.call_count == 0 assert credential._client._cae_cache is not None @@ -241,7 +252,8 @@ async def test_token_cache_memory(cert_path, cert_password): @pytest.mark.asyncio @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -async def test_token_cache_persistent(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_cache_persistent(cert_path, cert_password, get_token_method): """the credential should optionally use a persistent cache, and default to an in memory cache""" access_token = "token" @@ -267,14 +279,17 @@ async def test_token_cache_persistent(cert_path, cert_password): assert not mock_token_cache.called assert not load_persistent_cache.called - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert load_persistent_cache.call_count == 1 assert credential._client._cache is not None assert credential._client._cae_cache is None args, _ = load_persistent_cache.call_args assert args[1] is False - await credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(credential, get_token_method)("scope", **kwargs) assert load_persistent_cache.call_count == 2 assert credential._client._cae_cache is not None args, _ = load_persistent_cache.call_args @@ -284,7 +299,8 @@ async def test_token_cache_persistent(cert_path, cert_password): @pytest.mark.asyncio @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -async def test_persistent_cache_multiple_clients(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_persistent_cache_multiple_clients(cert_path, cert_password, get_token_method): """the credential shouldn't use tokens issued to other service principals""" access_token_a = "token a" 
@@ -321,7 +337,7 @@ async def test_persistent_cache_multiple_clients(cert_path, cert_password): # A caches a token scope = "scope" - token_a = await credential_a.get_token(scope) + token_a = await getattr(credential_a, get_token_method)(scope) assert token_a.token == access_token_a assert transport_a.send.call_count == 1 assert mock_cache_loader.call_count == 1 @@ -329,7 +345,7 @@ async def test_persistent_cache_multiple_clients(cert_path, cert_password): assert args[1] is False # not CAE # B should get a different token for the same scope - token_b = await credential_b.get_token(scope) + token_b = await getattr(credential_b, get_token_method)(scope) assert token_b.token == access_token_b assert transport_b.send.call_count == 1 assert mock_cache_loader.call_count == 2 @@ -348,7 +364,8 @@ def test_certificate_arguments(): @pytest.mark.asyncio @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -async def test_multitenant_authentication(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(cert_path, cert_password, get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -372,23 +389,30 @@ async def test_multitenant_authentication(cert_path, cert_password): transport=Mock(send=send), additionally_allowed_tenants=["*"], ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + 
token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token @pytest.mark.asyncio @pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS) -async def test_multitenant_authentication_backcompat(cert_path, cert_password): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication_backcompat(cert_path, cert_password, get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -410,12 +434,18 @@ async def test_multitenant_authentication_backcompat(cert_path, cert_password): additionally_allowed_tenants=["*"], ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token + kwargs = {"tenant_id": expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} # explicitly specifying the configured tenant is okay - token = await credential.get_token("scope", tenant_id=expected_tenant) + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token - token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token * 2 diff --git a/sdk/identity/azure-identity/tests/test_chained_credential.py b/sdk/identity/azure-identity/tests/test_chained_credential.py index ea9efbbf43f..0099802d0c6 100644 --- a/sdk/identity/azure-identity/tests/test_chained_credential.py +++ b/sdk/identity/azure-identity/tests/test_chained_credential.py @@ -5,7 +5,7 @@ import time from unittest.mock import 
Mock, MagicMock, patch -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.core.exceptions import ClientAuthenticationError from azure.identity._credentials.imds import IMDS_TOKEN_PATH, IMDS_AUTHORITY from azure.identity._internal.user_agent import USER_AGENT @@ -17,7 +17,7 @@ from azure.identity import ( ) import pytest -from helpers import validating_transport, Request, mock_response +from helpers import validating_transport, Request, mock_response, GET_TOKEN_METHODS def test_close(): @@ -51,74 +51,116 @@ def test_context_manager(): assert credential.__exit__.call_count == 1 -def test_error_message(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_error_message(get_token_method): first_error = "first_error" first_credential = Mock( - spec=ClientSecretCredential, get_token=Mock(side_effect=CredentialUnavailableError(first_error)) + spec=ClientSecretCredential, + get_token=Mock(side_effect=CredentialUnavailableError(first_error)), + get_token_info=Mock(side_effect=CredentialUnavailableError(first_error)), ) second_error = "second_error" second_credential = Mock( - name="second_credential", get_token=Mock(side_effect=ClientAuthenticationError(second_error)) + name="second_credential", + get_token=Mock(side_effect=ClientAuthenticationError(second_error)), + get_token_info=Mock(side_effect=ClientAuthenticationError(second_error)), ) with pytest.raises(ClientAuthenticationError) as ex: - ChainedTokenCredential(first_credential, second_credential).get_token("scope") + chained_cred = ChainedTokenCredential(first_credential, second_credential) + getattr(chained_cred, get_token_method)("scope") assert "ClientSecretCredential" in ex.value.message assert first_error in ex.value.message assert second_error in ex.value.message -def test_attempts_all_credentials(): - expected_token = AccessToken("expected_token", 0) +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) 
+def test_attempts_all_credentials(get_token_method): + expected_token = "expected_token" + expires_on = 42 credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message=""))), - Mock(spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message=""))), - Mock(spec_set=["get_token"], get_token=Mock(return_value=expected_token)), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=AccessToken(expected_token, expires_on)), + get_token_info=Mock(return_value=AccessTokenInfo(expected_token, expires_on)), + ), ] - token = ChainedTokenCredential(*credentials).get_token("scope") - assert token is expected_token + token = getattr(ChainedTokenCredential(*credentials), get_token_method)("scope") + assert token.token == expected_token for credential in credentials: - assert credential.get_token.call_count == 1 + assert getattr(credential, get_token_method).call_count == 1 -def test_raises_for_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_raises_for_unexpected_error(get_token_method): """the chain should not continue after an unexpected error (i.e. 
anything but CredentialUnavailableError)""" expected_message = "it can't be done" credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message=""))), - Mock(spec_set=["get_token"], get_token=Mock(side_effect=ValueError(expected_message))), - Mock(spec_set=["get_token"], get_token=Mock(return_value=AccessToken("**", 42))), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=ValueError(expected_message)), + get_token_info=Mock(side_effect=ValueError(expected_message)), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=AccessToken("**", 42)), + get_token_info=Mock(return_value=AccessTokenInfo("**", 42)), + ), ] with pytest.raises(ClientAuthenticationError) as ex: - ChainedTokenCredential(*credentials).get_token("scope") + getattr(ChainedTokenCredential(*credentials), get_token_method)("scope") assert expected_message in ex.value.message - assert credentials[-1].get_token.call_count == 0 + assert getattr(credentials[-1], get_token_method).call_count == 0 -def test_returns_first_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_returns_first_token(get_token_method): expected_token = Mock() - first_credential = Mock(spec_set=["get_token"], get_token=lambda _, **__: expected_token) - second_credential = Mock(spec_set=["get_token"], get_token=Mock()) + first_credential = Mock( + spec_set=["get_token", "get_token_info"], + get_token=lambda _, **__: expected_token, + get_token_info=lambda _, **__: expected_token, + ) + second_credential = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(), + get_token_info=Mock(), + ) aggregate = ChainedTokenCredential(first_credential, second_credential) - credential = 
aggregate.get_token("scope") + token = getattr(aggregate, get_token_method)("scope") - assert credential is expected_token - assert second_credential.get_token.call_count == 0 + assert token.token == expected_token.token + assert getattr(second_credential, get_token_method).call_count == 0 -def test_managed_identity_imds_probe(): - access_token = "****" +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_managed_identity_imds_probe(get_token_method): + expected_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) scope = "scope" transport = validating_transport( requests=[ @@ -135,7 +177,7 @@ def test_managed_identity_imds_probe(): mock_response(status_code=400, json_payload={"error": "this is an error message"}), mock_response( json_payload={ - "access_token": access_token, + "access_token": expected_token, "expires_in": 42, "expires_on": expires_on, "ext_expires_in": 42, @@ -149,28 +191,107 @@ def test_managed_identity_imds_probe(): with patch.dict("os.environ", clear=True): credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message=""))), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ), ManagedIdentityCredential(transport=transport), ] - token = ChainedTokenCredential(*credentials).get_token(scope) - assert token.token == expected_token.token + token = getattr(ChainedTokenCredential(*credentials), get_token_method)(scope) + assert token.token == expected_token -def test_managed_identity_failed_probe(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_managed_identity_failed_probe(get_token_method): mock_send = Mock(side_effect=Exception("timeout")) transport = Mock(send=mock_send) expected_token = Mock() credentials = [ - Mock(spec_set=["get_token"], 
get_token=Mock(side_effect=CredentialUnavailableError(message=""))), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ), ManagedIdentityCredential(transport=transport), - Mock(spec_set=["get_token"], get_token=Mock(return_value=expected_token)), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=expected_token), + get_token_info=Mock(return_value=expected_token), + ), ] with patch.dict("os.environ", clear=True): - token = ChainedTokenCredential(*credentials).get_token("scope") + token = getattr(ChainedTokenCredential(*credentials), get_token_method)("scope") - assert token is expected_token + assert token.token == expected_token.token # ManagedIdentityCredential should be tried and skipped with the last credential in the chain # being used. - assert credentials[-1].get_token.call_count == 1 + assert getattr(credentials[-1], get_token_method).call_count == 1 + + +def test_credentials_with_no_get_token_info(): + """ChainedTokenCredential should work with credentials that don't implement get_token_info.""" + + access_token = "****" + credential1 = Mock( + spec_set=["get_token_info"], + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ) + credential2 = Mock( + spec_set=["get_token"], + get_token=Mock(return_value=AccessToken(access_token, 42)), + ) + credential3 = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=AccessToken("foo", 42)), + get_token_info=Mock(return_value=AccessTokenInfo("bar", 42)), + ) + chain = ChainedTokenCredential(credential1, credential2, credential3) # type: ignore + token_info = chain.get_token_info("scope") + assert token_info.token == access_token + + +def test_credentials_with_no_get_token(): + """ChainedTokenCredential should work with credentials that only implement get_token_info.""" + + access_token = 
"****" + credential1 = Mock( + spec_set=["get_token"], + get_token=Mock(side_effect=CredentialUnavailableError(message="")), + ) + credential2 = Mock( + spec_set=["get_token_info"], + get_token_info=Mock(return_value=AccessTokenInfo(access_token, 42)), + ) + credential3 = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=AccessToken("foo", 42)), + get_token_info=Mock(return_value=AccessTokenInfo("bar", 42)), + ) + chain = ChainedTokenCredential(credential1, credential2, credential3) # type: ignore + token_info = chain.get_token("scope") + assert token_info.token == access_token + + +def test_credentials_with_pop_option(): + """ChainedTokenCredential should skip credentials that don't support get_token_info and the pop option is set.""" + + access_token = "****" + credential1 = Mock( + spec_set=["get_token_info"], + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ) + credential2 = Mock( + spec_set=["get_token"], + get_token=Mock(return_value=AccessToken("foo", 42)), + ) + credential3 = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=AccessToken("bar", 42)), + get_token_info=Mock(return_value=AccessTokenInfo(access_token, 42)), + ) + chain = ChainedTokenCredential(credential1, credential2, credential3) # type: ignore + token_info = chain.get_token_info("scope", options={"pop": True}) # type: ignore + assert token_info.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py b/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py index b145e3f93ba..fb668685669 100644 --- a/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_chained_token_credential_async.py @@ -5,7 +5,7 @@ import time from unittest.mock import Mock, patch -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from 
azure.core.exceptions import ClientAuthenticationError from azure.identity import CredentialUnavailableError, ClientSecretCredential from azure.identity.aio import ChainedTokenCredential, ManagedIdentityCredential @@ -13,7 +13,7 @@ from azure.identity._credentials.imds import IMDS_TOKEN_PATH, IMDS_AUTHORITY from azure.identity._internal.user_agent import USER_AGENT import pytest -from helpers import mock_response, Request +from helpers import mock_response, Request, GET_TOKEN_METHODS from helpers_async import get_completed_future, wrap_in_future, async_validating_transport @@ -41,18 +41,23 @@ async def test_context_manager(): @pytest.mark.asyncio -async def test_credential_chain_error_message(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_credential_chain_error_message(get_token_method): first_error = "first_error" first_credential = Mock( - spec=ClientSecretCredential, get_token=Mock(side_effect=CredentialUnavailableError(first_error)) + spec=ClientSecretCredential, + get_token=Mock(side_effect=CredentialUnavailableError(first_error)), + get_token_info=Mock(side_effect=CredentialUnavailableError(first_error)), ) second_error = "second_error" second_credential = Mock( - name="second_credential", get_token=Mock(side_effect=ClientAuthenticationError(second_error)) + name="second_credential", + get_token=Mock(side_effect=ClientAuthenticationError(second_error)), + get_token_info=Mock(side_effect=ClientAuthenticationError(second_error)), ) with pytest.raises(ClientAuthenticationError) as ex: - await ChainedTokenCredential(first_credential, second_credential).get_token("scope") + await getattr(ChainedTokenCredential(first_credential, second_credential), get_token_method)("scope") assert "ClientSecretCredential" in ex.value.message assert first_error in ex.value.message @@ -60,26 +65,40 @@ async def test_credential_chain_error_message(): @pytest.mark.asyncio -async def test_chain_attempts_all_credentials(): 
+@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_chain_attempts_all_credentials(get_token_method): async def credential_unavailable(message="it didn't work", **_): raise CredentialUnavailableError(message) - expected_token = AccessToken("expected_token", 0) + access_token = "expected_token" credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(wraps=credential_unavailable)), - Mock(spec_set=["get_token"], get_token=Mock(wraps=credential_unavailable)), - Mock(spec_set=["get_token"], get_token=wrap_in_future(lambda _, **__: expected_token)), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=credential_unavailable), + get_token_info=Mock(wraps=credential_unavailable), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=credential_unavailable), + get_token_info=Mock(wraps=credential_unavailable), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=wrap_in_future(lambda _, **__: AccessToken(access_token, 42)), + get_token_info=wrap_in_future(lambda _, **__: AccessTokenInfo(access_token, 42)), + ), ] - token = await ChainedTokenCredential(*credentials).get_token("scope") - assert token is expected_token + token = await getattr(ChainedTokenCredential(*credentials), get_token_method)("scope") + assert token.token == access_token for credential in credentials[:-1]: - assert credential.get_token.call_count == 1 + assert getattr(credential, get_token_method).call_count == 1 @pytest.mark.asyncio -async def test_chain_raises_for_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_chain_raises_for_unexpected_error(get_token_method): """the chain should not continue after an unexpected error (i.e. 
anything but CredentialUnavailableError)""" async def credential_unavailable(message="it didn't work", **_): @@ -88,36 +107,53 @@ async def test_chain_raises_for_unexpected_error(): expected_message = "it can't be done" credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(wraps=credential_unavailable)), - Mock(spec_set=["get_token"], get_token=Mock(side_effect=ValueError(expected_message))), - Mock(spec_set=["get_token"], get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken("**", 42)))), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=credential_unavailable), + get_token_info=Mock(wraps=credential_unavailable), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=ValueError(expected_message)), + get_token_info=Mock(side_effect=ValueError(expected_message)), + ), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken("**", 42))), + get_token_info=Mock(wraps=wrap_in_future(lambda _, **__: AccessTokenInfo("**", 42))), + ), ] with pytest.raises(ClientAuthenticationError) as ex: - await ChainedTokenCredential(*credentials).get_token("scope") + await getattr(ChainedTokenCredential(*credentials), get_token_method)("scope") assert expected_message in ex.value.message - assert credentials[-1].get_token.call_count == 0 + assert getattr(credentials[-1], get_token_method).call_count == 0 @pytest.mark.asyncio -async def test_returns_first_token(): - expected_token = Mock() - first_credential = Mock(spec_set=["get_token"], get_token=wrap_in_future(lambda _, **__: expected_token)) - second_credential = Mock(spec_set=["get_token"], get_token=Mock()) +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_returns_first_token(get_token_method): + access_token = "expected_token" + first_credential = Mock( + spec_set=["get_token", "get_token_info"], + get_token=wrap_in_future(lambda _, **__: AccessToken(access_token, 
42)), + get_token_info=wrap_in_future(lambda _, **__: AccessTokenInfo(access_token, 42)), + ) + second_credential = Mock(spec_set=["get_token", "get_token_info"], get_token=Mock(), get_token_info=Mock()) aggregate = ChainedTokenCredential(first_credential, second_credential) - credential = await aggregate.get_token("scope") + token = await getattr(aggregate, get_token_method)("scope") - assert credential is expected_token - assert second_credential.get_token.call_count == 0 + assert token.token == access_token + assert getattr(second_credential, get_token_method).call_count == 0 @pytest.mark.asyncio -async def test_managed_identity_imds_probe(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_managed_identity_imds_probe(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) scope = "scope" transport = async_validating_transport( requests=[ @@ -148,32 +184,123 @@ async def test_managed_identity_imds_probe(): # ensure e.g. 
$MSI_ENDPOINT isn't set, so we get ImdsCredential with patch.dict("os.environ", clear=True): credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message=""))), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="")), + ), ManagedIdentityCredential(transport=transport), ] - token = await ChainedTokenCredential(*credentials).get_token(scope) - assert token == expected_token + token = await getattr(ChainedTokenCredential(*credentials), get_token_method)(scope) + assert token.token == access_token @pytest.mark.asyncio -async def test_managed_identity_failed_probe(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_managed_identity_failed_probe(get_token_method): async def credential_unavailable(message="it didn't work", **_): raise CredentialUnavailableError(message) mock_send = Mock(side_effect=Exception("timeout")) transport = Mock(send=wrap_in_future(mock_send)) - expected_token = AccessToken("**", 42) + expected_token = "***" credentials = [ - Mock(spec_set=["get_token"], get_token=Mock(wraps=credential_unavailable)), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=credential_unavailable), + get_token_info=Mock(wraps=credential_unavailable), + ), ManagedIdentityCredential(transport=transport), - Mock(spec_set=["get_token"], get_token=Mock(wraps=wrap_in_future(lambda _, **__: expected_token))), + Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken(expected_token, 42))), + get_token_info=Mock(wraps=wrap_in_future(lambda _, **__: AccessTokenInfo(expected_token, 42))), + ), ] with patch.dict("os.environ", clear=True): - token = await ChainedTokenCredential(*credentials).get_token("scope") + token = await getattr(ChainedTokenCredential(*credentials), 
get_token_method)("scope") - assert token is expected_token + assert token.token == expected_token # ManagedIdentityCredential should be tried and skipped with the last credential in the chain # being used. - assert credentials[-1].get_token.call_count == 1 + assert getattr(credentials[-1], get_token_method).call_count == 1 + + +@pytest.mark.asyncio +async def test_credentials_with_no_get_token_info(): + """ChainedTokenCredential should work with credentials that don't implement get_token_info.""" + + async def credential_unavailable(message="it didn't work", **_): + raise CredentialUnavailableError(message) + + access_token = "****" + credential1 = Mock( + spec_set=["get_token_info"], + get_token_info=Mock(wraps=credential_unavailable), + ) + credential2 = Mock( + spec_set=["get_token"], + get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken(access_token, 42))), + ) + credential3 = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken("foo", 42))), + get_token_info=Mock(wraps=wrap_in_future(lambda _, **__: AccessTokenInfo("bar", 42))), + ) + chain = ChainedTokenCredential(credential1, credential2, credential3) # type: ignore + token_info = await chain.get_token_info("scope") + assert token_info.token == access_token + + +@pytest.mark.asyncio +async def test_credentials_with_no_get_token(): + """ChainedTokenCredential should work with credentials that only implement get_token_info.""" + + async def credential_unavailable(message="it didn't work", **_): + raise CredentialUnavailableError(message) + + access_token = "****" + credential1 = Mock( + spec_set=["get_token"], + get_token=Mock(wraps=credential_unavailable), + ) + credential2 = Mock( + spec_set=["get_token_info"], + get_token_info=Mock(wraps=wrap_in_future(lambda _, **__: AccessTokenInfo(access_token, 42))), + ) + credential3 = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=wrap_in_future(lambda _, **__: 
AccessToken("foo", 42))), + get_token_info=Mock(wraps=wrap_in_future(lambda _, **__: AccessTokenInfo("bar", 42))), + ) + chain = ChainedTokenCredential(credential1, credential2, credential3) # type: ignore + token_info = await chain.get_token("scope") + assert token_info.token == access_token + + +@pytest.mark.asyncio +async def test_credentials_with_pop_option(): + """ChainedTokenCredential should skip credentials that don't support get_token_info and the pop option is set.""" + + async def credential_unavailable(message="it didn't work", **_): + raise CredentialUnavailableError(message) + + access_token = "****" + credential1 = Mock( + spec_set=["get_token_info"], + get_token_info=Mock(wraps=credential_unavailable), + ) + credential2 = Mock( + spec_set=["get_token"], + get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken("foo", 42))), + ) + credential3 = Mock( + spec_set=["get_token", "get_token_info"], + get_token=Mock(wraps=wrap_in_future(lambda _, **__: AccessToken("bar", 42))), + get_token_info=Mock(wraps=wrap_in_future(lambda _, **__: AccessTokenInfo(access_token, 42))), + ) + chain = ChainedTokenCredential(credential1, credential2, credential3) # type: ignore + token_info = await chain.get_token_info("scope", options={"pop": True}) # type: ignore + assert token_info.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_cli_credential.py b/sdk/identity/azure-identity/tests/test_cli_credential.py index 0b688df36fa..2ed7c5657f7 100644 --- a/sdk/identity/azure-identity/tests/test_cli_credential.py +++ b/sdk/identity/azure-identity/tests/test_cli_credential.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. 
# ------------------------------------ from datetime import datetime +from itertools import product import json import re @@ -14,7 +15,7 @@ from azure.core.exceptions import ClientAuthenticationError import subprocess import pytest -from helpers import mock, INVALID_CHARACTERS +from helpers import mock, INVALID_CHARACTERS, GET_TOKEN_METHODS CHECK_OUTPUT = AzureCliCredential.__module__ + ".subprocess.check_output" @@ -35,21 +36,24 @@ def raise_called_process_error(return_code, output="", cmd="...", stderr=""): return mock.Mock(side_effect=error) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" with pytest.raises(ValueError): - AzureCliCredential().get_token() + getattr(AzureCliCredential(), get_token_method)() -def test_multiple_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multiple_scopes(get_token_method): """The credential should raise ValueError when get_token is called with more than one scope""" with pytest.raises(ValueError): - AzureCliCredential().get_token("one scope", "and another") + getattr(AzureCliCredential(), get_token_method)("one scope", "and another") -def test_invalid_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_invalid_tenant_id(get_token_method): """Invalid tenant IDs should raise ValueErrors.""" for c in INVALID_CHARACTERS: @@ -57,18 +61,23 @@ def test_invalid_tenant_id(): AzureCliCredential(tenant_id="tenant" + c) with pytest.raises(ValueError): - AzureCliCredential().get_token("scope", tenant_id="tenant" + c) + kwargs = {"tenant_id": "tenant" + c} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(AzureCliCredential(), get_token_method)("scope", **kwargs) -def test_invalid_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def 
test_invalid_scopes(get_token_method): """Scopes with invalid characters should raise ValueErrors.""" for c in INVALID_CHARACTERS: with pytest.raises(ValueError): - AzureCliCredential().get_token("scope" + c) + getattr(AzureCliCredential(), get_token_method)("scope" + c) -def test_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_get_token(get_token_method): """The credential should parse the CLI's output to an AccessToken""" access_token = "access token" @@ -85,14 +94,15 @@ def test_get_token(): with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=successful_output)): - token = AzureCliCredential().get_token("scope") + token = getattr(AzureCliCredential(), get_token_method)("scope") assert token.token == access_token assert type(token.expires_on) == int assert token.expires_on == expected_expires_on -def test_expires_on_used(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_expires_on_used(get_token_method): """Test that 'expires_on' is preferred over 'expiresOn'.""" expires_on = 1602015811 successful_output = json.dumps( @@ -108,12 +118,13 @@ def test_expires_on_used(): with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=successful_output)): - token = AzureCliCredential().get_token("scope") + token = getattr(AzureCliCredential(), get_token_method)("scope") assert token.expires_on == expires_on -def test_expires_on_string(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_expires_on_string(get_token_method): """Test that 'expires_on' still works if it's a string.""" expires_on = 1602015811 successful_output = json.dumps( @@ -128,85 +139,91 @@ def test_expires_on_string(): with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=successful_output)): - token = AzureCliCredential().get_token("scope") + token = 
getattr(AzureCliCredential(), get_token_method)("scope") assert type(token.expires_on) == int assert token.expires_on == expires_on -def test_cli_not_installed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cli_not_installed(get_token_method): """The credential should raise CredentialUnavailableError when the CLI isn't installed""" with mock.patch("shutil.which", return_value=None): with pytest.raises(CredentialUnavailableError, match=CLI_NOT_FOUND): - AzureCliCredential().get_token("scope") + getattr(AzureCliCredential(), get_token_method)("scope") -def test_cannot_execute_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cannot_execute_shell(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess doesn't start""" with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, mock.Mock(side_effect=OSError())): with pytest.raises(CredentialUnavailableError): - AzureCliCredential().get_token("scope") + getattr(AzureCliCredential(), get_token_method)("scope") -def test_not_logged_in(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_not_logged_in(get_token_method): """When the CLI isn't logged in, the credential should raise CredentialUnavailableError""" stderr = "ERROR: Please run 'az login' to setup account." 
with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(1, stderr=stderr)): with pytest.raises(CredentialUnavailableError, match=NOT_LOGGED_IN): - AzureCliCredential().get_token("scope") + getattr(AzureCliCredential(), get_token_method)("scope") -def test_aadsts_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_aadsts_error(get_token_method): """When the CLI isn't logged in, the credential should raise CredentialUnavailableError""" stderr = "ERROR: AADSTS70043: The refresh token has expired, Please run 'az login' to setup account." with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(1, stderr=stderr)): with pytest.raises(ClientAuthenticationError, match=stderr): - AzureCliCredential().get_token("scope") + getattr(AzureCliCredential(), get_token_method)("scope") -def test_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_unexpected_error(get_token_method): """When the CLI returns an unexpected error, the credential should raise an error containing the CLI's output""" stderr = "something went wrong" with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(42, stderr=stderr)): with pytest.raises(ClientAuthenticationError, match=stderr): - AzureCliCredential().get_token("scope") + getattr(AzureCliCredential(), get_token_method)("scope") -@pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -def test_parsing_error_does_not_expose_token(output): +@pytest.mark.parametrize("output,get_token_method", product(TEST_ERROR_OUTPUTS, GET_TOKEN_METHODS)) +def test_parsing_error_does_not_expose_token(output, get_token_method): """Errors during CLI output parsing shouldn't expose access tokens in that output""" with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=output)): with 
pytest.raises(ClientAuthenticationError) as ex: - AzureCliCredential().get_token("scope") + getattr(AzureCliCredential(), get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) -@pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -def test_subprocess_error_does_not_expose_token(output): +@pytest.mark.parametrize("output,get_token_method", product(TEST_ERROR_OUTPUTS, GET_TOKEN_METHODS)) +def test_subprocess_error_does_not_expose_token(output, get_token_method): """Errors from the subprocess shouldn't expose access tokens in CLI output""" with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, raise_called_process_error(1, output=output)): with pytest.raises(ClientAuthenticationError) as ex: - AzureCliCredential().get_token("scope") + getattr(AzureCliCredential(), get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) -def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_timeout(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess times out""" from subprocess import TimeoutExpired @@ -214,7 +231,7 @@ def test_timeout(): with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, mock.Mock(side_effect=TimeoutExpired("", 42))) as check_output_mock: with pytest.raises(CredentialUnavailableError): - AzureCliCredential(process_timeout=42).get_token("scope") + getattr(AzureCliCredential(process_timeout=42), get_token_method)("scope") # Ensure custom timeout is passed to subprocess _, kwargs = check_output_mock.call_args @@ -222,7 +239,8 @@ def test_timeout(): assert kwargs["timeout"] == 42 -def test_multitenant_authentication_class(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_class(get_token_method): default_tenant = "first-tenant" first_token = "***" 
second_tenant = "second-tenant" @@ -244,17 +262,18 @@ def test_multitenant_authentication_class(): with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, fake_check_output): - token = AzureCliCredential().get_token("scope") + token = getattr(AzureCliCredential(), get_token_method)("scope") assert token.token == first_token - token = AzureCliCredential(tenant_id=default_tenant).get_token("scope") + token = getattr(AzureCliCredential(tenant_id=default_tenant), get_token_method)("scope") assert token.token == first_token - token = AzureCliCredential(tenant_id=second_tenant).get_token("scope") + token = getattr(AzureCliCredential(tenant_id=second_tenant), get_token_method)("scope") assert token.token == second_token -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): default_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -277,21 +296,28 @@ def test_multitenant_authentication(): credential = AzureCliCredential() with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, fake_check_output): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, 
get_token_method)("scope") assert token.token == first_token -def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -311,9 +337,12 @@ def test_multitenant_authentication_not_allowed(): credential = AzureCliCredential() with mock.patch("shutil.which", return_value="az"): with mock.patch(CHECK_OUTPUT, fake_check_output): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_cli_credential_async.py b/sdk/identity/azure-identity/tests/test_cli_credential_async.py index d771dd7d77d..7865a3e6b38 100644 --- a/sdk/identity/azure-identity/tests/test_cli_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_cli_credential_async.py @@ -4,6 +4,7 @@ # ------------------------------------ import asyncio from datetime import datetime +from itertools import product import json import re import sys @@ -16,7 +17,7 @@ from azure.identity._credentials.azure_cli import CLI_NOT_FOUND, NOT_LOGGED_IN from azure.core.exceptions import ClientAuthenticationError import pytest -from helpers import INVALID_CHARACTERS +from helpers import INVALID_CHARACTERS, GET_TOKEN_METHODS from helpers_async import get_completed_future from test_cli_credential import TEST_ERROR_OUTPUTS @@ -33,21 +34,24 @@ def mock_exec(stdout, stderr="", return_code=0): return 
mock.Mock(return_value=get_completed_future(process)) -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" with pytest.raises(ValueError): - await AzureCliCredential().get_token() + await getattr(AzureCliCredential(), get_token_method)() -async def test_multiple_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multiple_scopes(get_token_method): """The credential should raise ValueError when get_token is called with more than one scope""" with pytest.raises(ValueError): - await AzureCliCredential().get_token("one scope", "and another") + await getattr(AzureCliCredential(), get_token_method)("one scope", "and another") -async def test_invalid_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_invalid_tenant_id(get_token_method): """Invalid tenant IDs should raise ValueErrors.""" for c in INVALID_CHARACTERS: @@ -55,15 +59,19 @@ async def test_invalid_tenant_id(): AzureCliCredential(tenant_id="tenant" + c) with pytest.raises(ValueError): - await AzureCliCredential().get_token("scope", tenant_id="tenant" + c) + kwargs = {"tenant_id": "tenant" + c} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(AzureCliCredential(), get_token_method)("scope", **kwargs) -async def test_invalid_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_invalid_scopes(get_token_method): """Scopes with invalid characters should raise ValueErrors.""" for c in INVALID_CHARACTERS: with pytest.raises(ValueError): - await AzureCliCredential().get_token("https://scope" + c) + await getattr(AzureCliCredential(), get_token_method)("https://scope" + c) async def test_close(): @@ -80,21 +88,25 @@ async def test_context_manager(): @pytest.mark.skipif(not 
sys.platform.startswith("win"), reason="tests Windows-specific behavior") -async def test_windows_fallback(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_windows_fallback(get_token_method): """The credential should fall back to the sync implementation when not using ProactorEventLoop on Windows""" sync_get_token = mock.Mock() with mock.patch("azure.identity.aio._credentials.azure_cli._SyncAzureCliCredential") as fallback: - fallback.return_value = mock.Mock(spec_set=["get_token"], get_token=sync_get_token) + fallback.return_value = mock.Mock( + spec_set=["get_token", "get_token_info"], get_token=sync_get_token, get_token_info=sync_get_token + ) with mock.patch(AzureCliCredential.__module__ + ".asyncio.get_event_loop"): # asyncio.get_event_loop now returns Mock, i.e. never ProactorEventLoop credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert sync_get_token.call_count == 1 -async def test_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_get_token(get_token_method): """The credential should parse the CLI's output to an AccessToken""" access_token = "access token" @@ -112,14 +124,15 @@ async def test_get_token(): with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, mock_exec(successful_output)): credential = AzureCliCredential() - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == access_token assert type(token.expires_on) == int assert token.expires_on == expected_expires_on -async def test_expires_on_used(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_expires_on_used(get_token_method): """Test that 'expires_on' is preferred over 'expiresOn'.""" expires_on = 1602015811 successful_output = json.dumps( @@ -136,12 +149,13 @@ async def test_expires_on_used(): with 
mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, mock_exec(successful_output)): credential = AzureCliCredential() - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.expires_on == expires_on -async def test_expires_on_string(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_expires_on_string(get_token_method): """Test that 'expires_on' still works if it's a string.""" expires_on = 1602015811 successful_output = json.dumps( @@ -157,32 +171,35 @@ async def test_expires_on_string(): with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, mock_exec(successful_output)): credential = AzureCliCredential() - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert type(token.expires_on) == int assert token.expires_on == expires_on -async def test_cli_not_installed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cli_not_installed(get_token_method): """The credential should raise CredentialUnavailableError when the CLI isn't installed""" with mock.patch("shutil.which", return_value=None): with pytest.raises(CredentialUnavailableError, match=CLI_NOT_FOUND): credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_cannot_execute_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cannot_execute_shell(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess doesn't start""" with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, mock.Mock(side_effect=OSError())): with pytest.raises(CredentialUnavailableError): credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") 
-async def test_not_logged_in(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_not_logged_in(get_token_method): """When the CLI isn't logged in, the credential should raise CredentialUnavailableError""" stderr = "ERROR: Please run 'az login' to setup account." @@ -190,10 +207,11 @@ async def test_not_logged_in(): with mock.patch(SUBPROCESS_EXEC, mock_exec("", stderr, return_code=1)): with pytest.raises(CredentialUnavailableError, match=NOT_LOGGED_IN): credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_aadsts_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_aadsts_error(get_token_method): """When the CLI isn't logged in, the credential should raise CredentialUnavailableError""" stderr = "ERROR: AADSTS70043: The refresh token has expired, Please run 'az login' to setup account." @@ -201,10 +219,11 @@ async def test_aadsts_error(): with mock.patch(SUBPROCESS_EXEC, mock_exec("", stderr, return_code=1)): with pytest.raises(ClientAuthenticationError, match=stderr): credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_unexpected_error(get_token_method): """When the CLI returns an unexpected error, the credential should raise an error containing the CLI's output""" stderr = "something went wrong" @@ -212,50 +231,52 @@ async def test_unexpected_error(): with mock.patch(SUBPROCESS_EXEC, mock_exec("", stderr, return_code=42)): with pytest.raises(ClientAuthenticationError, match=stderr): credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -@pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -async def 
test_parsing_error_does_not_expose_token(output): +@pytest.mark.parametrize("output,get_token_method", product(TEST_ERROR_OUTPUTS, GET_TOKEN_METHODS)) +async def test_parsing_error_does_not_expose_token(output, get_token_method): """Errors during CLI output parsing shouldn't expose access tokens in that output""" with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, mock_exec(output)): with pytest.raises(ClientAuthenticationError) as ex: credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) -@pytest.mark.parametrize("output", TEST_ERROR_OUTPUTS) -async def test_subprocess_error_does_not_expose_token(output): +@pytest.mark.parametrize("output,get_token_method", product(TEST_ERROR_OUTPUTS, GET_TOKEN_METHODS)) +async def test_subprocess_error_does_not_expose_token(output, get_token_method): """Errors from the subprocess shouldn't expose access tokens in CLI output""" with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, mock_exec(output, return_code=1)): with pytest.raises(ClientAuthenticationError) as ex: credential = AzureCliCredential() - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert "secret value" not in str(ex.value) assert "secret value" not in repr(ex.value) -async def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_timeout(get_token_method): """The credential should kill the subprocess after a timeout""" proc = mock.Mock(communicate=mock.Mock(side_effect=asyncio.TimeoutError), returncode=None) with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, mock.Mock(return_value=get_completed_future(proc))): with pytest.raises(CredentialUnavailableError): - await AzureCliCredential().get_token("scope") + await 
getattr(AzureCliCredential(), get_token_method)("scope") assert proc.communicate.call_count == 1 -async def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(get_token_method): default_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -279,21 +300,28 @@ async def test_multitenant_authentication(): credential = AzureCliCredential() with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, fake_exec): - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token -async def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -314,9 +342,12 @@ async def test_multitenant_authentication_not_allowed(): credential = AzureCliCredential() with mock.patch("shutil.which", return_value="az"): with mock.patch(SUBPROCESS_EXEC, fake_exec): - token = await credential.get_token("scope") + token = await 
getattr(credential, get_token_method)("scope") assert token.token == expected_token with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_client_assertion_credential.py b/sdk/identity/azure-identity/tests/test_client_assertion_credential.py index 30586a55777..d1c1fb2beaf 100644 --- a/sdk/identity/azure-identity/tests/test_client_assertion_credential.py +++ b/sdk/identity/azure-identity/tests/test_client_assertion_credential.py @@ -7,8 +7,9 @@ from unittest.mock import MagicMock, Mock, patch from azure.identity._internal.aad_client_base import JWT_BEARER_ASSERTION from azure.identity import ClientAssertionCredential, TokenCachePersistenceOptions +import pytest -from helpers import build_aad_response, mock_response +from helpers import build_aad_response, mock_response, GET_TOKEN_METHODS def test_init_with_kwargs(): @@ -40,7 +41,8 @@ def test_context_manager(): assert transport.__exit__.called -def test_token_cache_persistence(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_cache_persistence(get_token_method): """The credential should use a persistent cache if cache_persistence_options are configured.""" access_token = "foo" @@ -72,12 +74,15 @@ def test_token_cache_persistence(): assert credential._client._cache is None assert credential._client._cae_cache is None - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token assert load_persistent_cache.call_count == 1 assert credential._client._cache is not None assert credential._client._cae_cache is None 
- token = credential.get_token(scope, enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)(scope, **kwargs) assert load_persistent_cache.call_count == 2 assert credential._client._cae_cache is not None diff --git a/sdk/identity/azure-identity/tests/test_client_assertion_credential_async.py b/sdk/identity/azure-identity/tests/test_client_assertion_credential_async.py index 18a2dcb4b57..ddf54bf35be 100644 --- a/sdk/identity/azure-identity/tests/test_client_assertion_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_client_assertion_credential_async.py @@ -10,7 +10,7 @@ from azure.identity._internal.aad_client_base import JWT_BEARER_ASSERTION from azure.identity import TokenCachePersistenceOptions from azure.identity.aio import ClientAssertionCredential -from helpers import build_aad_response, mock_response +from helpers import build_aad_response, mock_response, GET_TOKEN_METHODS def test_init_with_kwargs(): @@ -45,7 +45,8 @@ async def test_context_manager(): @pytest.mark.asyncio -async def test_token_cache_persistence(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_cache_persistence(get_token_method): """The credential should use a persistent cache if cache_persistence_options are configured.""" access_token = "foo" @@ -77,12 +78,15 @@ async def test_token_cache_persistence(): assert credential._client._cache is None assert credential._client._cae_cache is None - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token assert load_persistent_cache.call_count == 1 assert credential._client._cache is not None assert credential._client._cae_cache is None - token = await credential.get_token(scope, enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} 
+ token = await getattr(credential, get_token_method)(scope, **kwargs) assert load_persistent_cache.call_count == 2 assert credential._client._cae_cache is not None diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential.py b/sdk/identity/azure-identity/tests/test_client_secret_credential.py index 7b9b9a195ef..ae778605ad2 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential.py @@ -2,6 +2,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from itertools import product +from urllib.parse import urlparse +from unittest.mock import Mock, patch + from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import ClientSecretCredential, TokenCachePersistenceOptions from azure.identity._enums import RegionalAuthority @@ -10,7 +14,6 @@ from azure.identity._internal.user_agent import USER_AGENT from msal import TokenCache import msal import pytest -from urllib.parse import urlparse from helpers import ( build_aad_response, @@ -18,11 +21,10 @@ from helpers import ( get_discovery_response, id_token_claims, mock_response, - msal_validating_transport, new_msal_validating_transport, Request, + GET_TOKEN_METHODS, ) -from unittest.mock import Mock, patch def test_tenant_id_validation(): @@ -38,15 +40,17 @@ def test_tenant_id_validation(): ClientSecretCredential(tenant, "client-id", "secret") -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = ClientSecretCredential("tenant-id", "client-id", "client-secret") with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", 
GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) transport = new_msal_validating_transport( @@ -57,12 +61,13 @@ def test_policies_configurable(): "tenant-id", "client-id", "client-secret", policies=[ContentDecodePolicy(), policy], transport=transport ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert policy.on_request.called -def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_user_agent(get_token_method): transport = new_msal_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -70,10 +75,11 @@ def test_user_agent(): credential = ClientSecretCredential("tenant-id", "client-id", "client-secret", transport=transport) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_client_secret_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_client_secret_credential(get_token_method): client_id = "fake-client-id" secret = "fake-client-secret" tenant_id = "fake-tenant-id" @@ -85,13 +91,15 @@ def test_client_secret_credential(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token))], ) - token = ClientSecretCredential(tenant_id, client_id, secret, transport=transport).get_token("scope") + token = getattr(ClientSecretCredential(tenant_id, client_id, secret, transport=transport), get_token_method)( + "scope" + ) assert token.token == access_token -@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) -def test_authority(authority): +@pytest.mark.parametrize("authority,get_token_method", product(("localhost", "https://localhost"), GET_TOKEN_METHODS)) +def test_authority(authority, get_token_method): """the credential should accept an authority, with or without scheme, as an 
argument or environment variable""" tenant_id = "expected-tenant" @@ -106,7 +114,7 @@ def test_authority(authority): credential = ClientSecretCredential(tenant_id, "client-id", "secret", authority=authority) with patch("msal.ConfidentialClientApplication", mock_ctor): # must call get_token because the credential constructs the MSAL application lazily - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_ctor.call_count == 1 _, kwargs = mock_ctor.call_args @@ -117,14 +125,15 @@ def test_authority(authority): with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): credential = ClientSecretCredential(tenant_id, "client-id", "secret") with patch("msal.ConfidentialClientApplication", mock_ctor): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_ctor.call_count == 1 _, kwargs = mock_ctor.call_args assert kwargs["authority"] == expected_authority -def test_regional_authority(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_regional_authority(get_token_method): """the credential should configure MSAL with a regional authority specified via kwarg or environment variable""" mock_confidential_client = Mock( @@ -138,7 +147,7 @@ def test_regional_authority(): with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: region.value}, clear=True): credential = ClientSecretCredential("tenant", "client-id", "secret") with patch("msal.ConfidentialClientApplication", mock_confidential_client): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_confidential_client.call_count == 1 _, kwargs = mock_confidential_client.call_args @@ -148,7 +157,8 @@ def test_regional_authority(): assert kwargs["azure_region"] == region.value -def test_token_cache_persistent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def 
test_token_cache_persistent(get_token_method): """the credential should use a persistent cache if cache_persistence_options are configured""" access_token = "foo token" @@ -176,18 +186,22 @@ def test_token_cache_persistent(): assert credential._cache is None assert credential._cae_cache is None - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == access_token assert load_persistent_cache.call_count == 1 assert credential._cache is not None assert credential._cae_cache is None - token = credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert load_persistent_cache.call_count == 2 assert credential._cae_cache is not None -def test_token_cache_memory(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_cache_memory(get_token_method): """The credential should default to in-memory cache if no persistence options are provided.""" access_token = "foo token" @@ -205,18 +219,22 @@ def test_token_cache_memory(): credential = ClientSecretCredential("tenant", "client-id", "secret", transport=Mock(send=send)) assert credential._cache is None - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == access_token assert isinstance(credential._cache, TokenCache) assert credential._cae_cache is None assert not load_persistent_cache.called - token = credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert isinstance(credential._cae_cache, TokenCache) assert not load_persistent_cache.called -def test_cache_multiple_clients(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def 
test_cache_multiple_clients(get_token_method): """the credential shouldn't use tokens issued to other service principals""" access_token_a = "token a" @@ -249,13 +267,13 @@ def test_cache_multiple_clients(): # A caches a token scope = "scope" - token_a = credential_a.get_token(scope) + token_a = getattr(credential_a, get_token_method)(scope) assert mock_cache_loader.call_count == 1 assert token_a.token == access_token_a assert transport_a.send.call_count == 2 # one MSAL discovery request, one token request # B should get a different token for the same scope - token_b = credential_b.get_token(scope) + token_b = getattr(credential_b, get_token_method)(scope) assert mock_cache_loader.call_count == 2 assert token_b.token == access_token_b assert transport_b.send.call_count == 2 @@ -263,7 +281,8 @@ def test_cache_multiple_clients(): assert len(list(cache.search(TokenCache.CredentialType.ACCESS_TOKEN))) == 2 -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -286,21 +305,28 @@ def test_multitenant_authentication(): credential = ClientSecretCredential( first_tenant, "client-id", "secret", transport=Mock(send=send), additionally_allowed_tenants=["*"] ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, 
get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -def test_live_multitenant_authentication(live_service_principal): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_live_multitenant_authentication(live_service_principal, get_token_method): # first create a credential with a non-existent tenant credential = ClientSecretCredential( "...", @@ -309,12 +335,15 @@ def test_live_multitenant_authentication(live_service_principal): additionally_allowed_tenants=["*"], ) # then get a valid token for an actual tenant - token = credential.get_token("https://vault.azure.net/.default", tenant_id=live_service_principal["tenant_id"]) + token = getattr(credential, get_token_method)( + "https://vault.azure.net/.default", tenant_id=live_service_principal["tenant_id"] + ) assert token.token assert token.expires_on -def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -332,18 +361,25 @@ def test_multitenant_authentication_not_allowed(): credential = ClientSecretCredential(expected_tenant, "client-id", "secret", transport=Mock(send=send)) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token - token = credential.get_token("scope", tenant_id=expected_tenant) + kwargs = {"tenant_id": expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} with 
patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = credential.get_token("scope", tenant_id="un" + expected_tenant) + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token -def test_client_capabilities(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_client_capabilities(get_token_method): """The credential should configure MSAL for capability only if enable_cae is passed in.""" transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent"))) @@ -362,7 +398,8 @@ def test_client_capabilities(): assert kwargs["client_capabilities"] == ["CP1"] -def test_claims_challenge(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_claims_challenge(get_token_method): """get_token should pass any claims challenge to MSAL token acquisition APIs""" msal_acquire_token_result = dict( @@ -378,7 +415,10 @@ def test_claims_challenge(): msal_app.acquire_token_silent_with_error.return_value = None msal_app.acquire_token_for_client.return_value = msal_acquire_token_result - credential.get_token("scope", claims=expected_claims) + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_silent_with_error.call_count == 1 args, kwargs = msal_app.acquire_token_silent_with_error.call_args @@ -389,7 +429,8 @@ def test_claims_challenge(): assert kwargs["claims_challenge"] == expected_claims -def test_msal_kwargs_filtered(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_msal_kwargs_filtered(get_token_method): msal_acquire_token_result = dict( build_aad_response(access_token="**", id_token=build_id_token()), id_token_claims=id_token_claims("issuer", "subject", "audience", upn="upn"), @@ -402,10 +443,12 @@ def test_msal_kwargs_filtered(): 
msal_app.acquire_token_silent_with_error.return_value = None msal_app.acquire_token_for_client.return_value = msal_acquire_token_result - credential.get_token("scope", claims=expected_claims, correlation_id="foo", enable_cae=True) + kwargs = {"claims": expected_claims, "enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_silent_with_error.call_count == 1 _, kwargs = msal_app.acquire_token_silent_with_error.call_args assert kwargs["claims_challenge"] == expected_claims - assert kwargs["correlation_id"] == "foo" assert "enable_cae" not in kwargs diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py index 0b4a91ac6a1..d8044ec5269 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. 
# ------------------------------------ import time +from itertools import product from unittest.mock import Mock, patch from urllib.parse import urlparse @@ -15,7 +16,7 @@ from azure.identity.aio import ClientSecretCredential from msal import TokenCache import pytest -from helpers import build_aad_response, mock_response, Request +from helpers import build_aad_response, mock_response, Request, GET_TOKEN_METHODS from helpers_async import async_validating_transport, AsyncMockTransport, wrap_in_future @@ -33,12 +34,13 @@ def test_tenant_id_validation(): @pytest.mark.asyncio -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = ClientSecretCredential("tenant-id", "client-id", "client-secret") with pytest.raises(ValueError): - await credential.get_token() + await getattr(credential, get_token_method)() @pytest.mark.asyncio @@ -52,7 +54,8 @@ async def test_close(): @pytest.mark.asyncio -async def test_context_manager(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_context_manager(get_token_method): transport = AsyncMockTransport() credential = ClientSecretCredential("tenant-id", "client-id", "client-secret", transport=transport) @@ -64,7 +67,8 @@ async def test_context_manager(): @pytest.mark.asyncio -async def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) async def send(*_, **kwargs): @@ -77,13 +81,14 @@ async def test_policies_configurable(): "tenant-id", "client-id", "client-secret", policies=[ContentDecodePolicy(), policy], transport=Mock(send=send) ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert policy.on_request.called 
@pytest.mark.asyncio -async def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_user_agent(get_token_method): transport = async_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -91,11 +96,12 @@ async def test_user_agent(): credential = ClientSecretCredential("tenant-id", "client-id", "client-secret", transport=transport) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.asyncio -async def test_client_secret_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_client_secret_credential(get_token_method): client_id = "fake-client-id" secret = "fake-client-secret" tenant_id = "fake-tenant-id" @@ -115,17 +121,18 @@ async def test_client_secret_credential(): ], ) - token = await ClientSecretCredential( - tenant_id=tenant_id, client_id=client_id, client_secret=secret, transport=transport - ).get_token("scope") + token = await getattr( + ClientSecretCredential(tenant_id=tenant_id, client_id=client_id, client_secret=secret, transport=transport), + get_token_method, + )("scope") # not validating expires_on because doing so requires monkeypatching time, and this is tested elsewhere assert token.token == access_token @pytest.mark.asyncio -@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) -async def test_request_url(authority): +@pytest.mark.parametrize("authority,get_token_method", product(("localhost", "https://localhost"), GET_TOKEN_METHODS)) +async def test_request_url(authority, get_token_method): """the credential should accept an authority, with or without scheme, as an argument or environment variable""" tenant_id = "expected-tenant" @@ -143,22 +150,23 @@ async def test_request_url(authority): credential = ClientSecretCredential( tenant_id, "client-id", "secret", 
transport=Mock(send=mock_send), authority=authority ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == access_token # authority can be configured via environment variable with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): credential = ClientSecretCredential(tenant_id, "client-id", "secret", transport=Mock(send=mock_send)) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert token.token == access_token @pytest.mark.asyncio -async def test_cache(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cache(get_token_method): expired = "this token's expired" now = int(time.time()) expired_on = now - 3600 - expired_token = AccessToken(expired, expired_on) + expired_token = expired token_payload = { "access_token": expired, "expires_in": 0, @@ -175,22 +183,22 @@ async def test_cache(): # get_token initially returns the expired token because the credential # doesn't check whether tokens it receives from the service have expired - token = await credential.get_token(scope) - assert token == expired_token + token = await getattr(credential, get_token_method)(scope) + assert token.token == expired_token access_token = "new token" token_payload["access_token"] = access_token token_payload["expires_on"] = now + 3600 - valid_token = AccessToken(access_token, now + 3600) # second call should observe the cached token has expired, and request another - token = await credential.get_token(scope) - assert token == valid_token + token = await getattr(credential, get_token_method)(scope) + assert token.token == access_token assert mock_send.call_count == 2 @pytest.mark.asyncio -async def test_token_cache(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_cache(get_token_method): """the credential should default to an in memory cache, and 
optionally use a persistent cache""" access_token = "token" @@ -208,20 +216,24 @@ async def test_token_cache(): assert mock_token_cache.call_count == 0 assert not load_persistent_cache.called - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert mock_token_cache.call_count == 1 assert load_persistent_cache.call_count == 0 assert credential._client._cache is not None assert credential._client._cae_cache is None - await credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(credential, get_token_method)("scope", **kwargs) assert mock_token_cache.call_count == 2 assert load_persistent_cache.call_count == 0 assert credential._client._cae_cache is not None @pytest.mark.asyncio -async def test_token_cache_persistent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_cache_persistent(get_token_method): """the credential should use persistent cache if passed in cache options.""" access_token = "token" @@ -241,14 +253,17 @@ async def test_token_cache_persistent(): cache_persistence_options=TokenCachePersistenceOptions(), transport=transport, ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert load_persistent_cache.call_count == 1 assert credential._client._cache is not None assert credential._client._cae_cache is None args, _ = load_persistent_cache.call_args assert args[1] is False - await credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(credential, get_token_method)("scope", **kwargs) assert load_persistent_cache.call_count == 2 assert credential._client._cae_cache is not None args, _ = load_persistent_cache.call_args @@ -256,7 +271,8 @@ async def test_token_cache_persistent(): @pytest.mark.asyncio -async def 
test_cache_multiple_clients(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cache_multiple_clients(get_token_method): """the credential shouldn't use tokens issued to other service principals""" access_token_a = "token a" @@ -291,7 +307,7 @@ async def test_cache_multiple_clients(): # A caches a token scope = "scope" - token_a = await credential_a.get_token(scope) + token_a = await getattr(credential_a, get_token_method)(scope) assert token_a.token == access_token_a assert transport_a.send.call_count == 1 assert mock_cache_loader.call_count == 1 @@ -299,7 +315,7 @@ async def test_cache_multiple_clients(): assert args[1] is False # B should get a different token for the same scope - token_b = await credential_b.get_token(scope) + token_b = await getattr(credential_b, get_token_method)(scope) assert token_b.token == access_token_b assert transport_b.send.call_count == 1 assert mock_cache_loader.call_count == 2 @@ -308,7 +324,8 @@ async def test_cache_multiple_clients(): @pytest.mark.asyncio -async def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -329,22 +346,29 @@ async def test_multitenant_authentication(): credential = ClientSecretCredential( first_tenant, "client-id", "secret", transport=Mock(send=send), additionally_allowed_tenants=["*"] ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = await credential.get_token("scope", 
tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token @pytest.mark.asyncio -async def test_live_multitenant_authentication(live_service_principal): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_live_multitenant_authentication(live_service_principal, get_token_method): # first create a credential with a non-existent tenant credential = ClientSecretCredential( "...", @@ -352,16 +376,18 @@ async def test_live_multitenant_authentication(live_service_principal): live_service_principal["client_secret"], additionally_allowed_tenants=["*"], ) + kwargs = {"tenant_id": live_service_principal["tenant_id"]} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} # then get a valid token for an actual tenant - token = await credential.get_token( - "https://vault.azure.net/.default", tenant_id=live_service_principal["tenant_id"] - ) + token = await getattr(credential, get_token_method)("https://vault.azure.net/.default", **kwargs) assert token.token assert token.expires_on @pytest.mark.asyncio -async def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -378,15 +404,21 @@ async def test_multitenant_authentication_not_allowed(): expected_tenant, "client-id", "secret", transport=Mock(send=send), additionally_allowed_tenants=["*"] ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == 
expected_token - token = await credential.get_token("scope", tenant_id=expected_tenant) + kwargs = {"tenant_id": expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token - token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token * 2 with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_context_manager.py b/sdk/identity/azure-identity/tests/test_context_manager.py index e2d98c80eb0..ba97ef28fb9 100644 --- a/sdk/identity/azure-identity/tests/test_context_manager.py +++ b/sdk/identity/azure-identity/tests/test_context_manager.py @@ -2,10 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
# ------------------------------------ -try: - from unittest.mock import MagicMock, patch -except ImportError: - from mock import MagicMock, patch # type: ignore +from unittest.mock import MagicMock, patch from azure.identity._credentials.application import AzureApplicationCredential from azure.identity import ( diff --git a/sdk/identity/azure-identity/tests/test_default.py b/sdk/identity/azure-identity/tests/test_default.py index bbcff2f013e..b77bfbcbad2 100644 --- a/sdk/identity/azure-identity/tests/test_default.py +++ b/sdk/identity/azure-identity/tests/test_default.py @@ -4,7 +4,7 @@ # ------------------------------------ import os -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.identity import ( AzureCliCredential, AzureDeveloperCliCredential, @@ -22,7 +22,7 @@ from azure.identity._credentials.managed_identity import ManagedIdentityCredenti import pytest from urllib.parse import urlparse -from helpers import mock_response, Request, validating_transport +from helpers import mock_response, Request, validating_transport, GET_TOKEN_METHODS from test_shared_cache_credential import build_aad_response, get_account_event, populated_cache from unittest.mock import MagicMock, Mock, patch @@ -50,28 +50,36 @@ def test_context_manager(): assert transport.__exit__.called -def test_iterates_only_once(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_iterates_only_once(get_token_method): """When a credential succeeds, DefaultAzureCredential should use that credential thereafter, ignoring the others""" unavailable_credential = Mock( - spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message="...")) + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="...")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="...")), + ) + successful_credential = Mock( + 
spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=AccessToken("***", 42)), + get_token_info=Mock(return_value=AccessTokenInfo("***", 42)), ) - successful_credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=AccessToken("***", 42))) credential = DefaultAzureCredential() - credential.credentials = [ + credential.credentials = ( unavailable_credential, successful_credential, Mock( - spec_set=["get_token"], + spec_set=["get_token", "get_token_info"], get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), + get_token_info=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), ), - ] + ) for n in range(3): - credential.get_token("scope") - assert unavailable_credential.get_token.call_count == 1 - assert successful_credential.get_token.call_count == n + 1 + getattr(credential, get_token_method)("scope") + assert getattr(unavailable_credential, get_token_method).call_count == 1 + assert getattr(successful_credential, get_token_method).call_count == n + 1 @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) @@ -174,7 +182,8 @@ def test_exclude_options(): assert actual - default == {InteractiveBrowserCredential} -def test_shared_cache_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_shared_cache_tenant_id(get_token_method): expected_access_token = "expected-access-token" refresh_token_a = "refresh-token-a" refresh_token_b = "refresh-token-b" @@ -195,14 +204,14 @@ def test_shared_cache_tenant_id(): credential = get_credential_for_shared_cache_test( refresh_token_b, expected_access_token, cache, shared_cache_tenant_id=tenant_b ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # redundantly specifying shared_cache_username makes no difference credential = get_credential_for_shared_cache_test( refresh_token_b, 
expected_access_token, cache, shared_cache_tenant_id=tenant_b, shared_cache_username=upn ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # shared_cache_tenant_id should prevail over AZURE_TENANT_ID @@ -210,17 +219,18 @@ def test_shared_cache_tenant_id(): credential = get_credential_for_shared_cache_test( refresh_token_b, expected_access_token, cache, shared_cache_tenant_id=tenant_b ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # AZURE_TENANT_ID should be used when shared_cache_tenant_id isn't specified with patch("os.environ", {EnvironmentVariables.AZURE_TENANT_ID: tenant_b}): credential = get_credential_for_shared_cache_test(refresh_token_b, expected_access_token, cache) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token -def test_shared_cache_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_shared_cache_username(get_token_method): expected_access_token = "expected-access-token" refresh_token_a = "refresh-token-a" refresh_token_b = "refresh-token-b" @@ -240,14 +250,14 @@ def test_shared_cache_username(): credential = get_credential_for_shared_cache_test( refresh_token_a, expected_access_token, cache, shared_cache_username=upn_a ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # redundantly specifying shared_cache_tenant_id makes no difference credential = get_credential_for_shared_cache_test( refresh_token_a, expected_access_token, cache, shared_cache_tenant_id=tenant_id, shared_cache_username=upn_a ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # 
shared_cache_username should prevail over AZURE_USERNAME @@ -255,13 +265,13 @@ def test_shared_cache_username(): credential = get_credential_for_shared_cache_test( refresh_token_a, expected_access_token, cache, shared_cache_username=upn_a ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # AZURE_USERNAME should be used when shared_cache_username isn't specified with patch("os.environ", {EnvironmentVariables.AZURE_USERNAME: upn_b}): credential = get_credential_for_shared_cache_test(refresh_token_b, expected_access_token, cache) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token diff --git a/sdk/identity/azure-identity/tests/test_default_async.py b/sdk/identity/azure-identity/tests/test_default_async.py index 57b466ebb30..18a942b705a 100644 --- a/sdk/identity/azure-identity/tests/test_default_async.py +++ b/sdk/identity/azure-identity/tests/test_default_async.py @@ -6,7 +6,7 @@ import os from unittest.mock import Mock, patch from urllib.parse import urlparse -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.identity import CredentialUnavailableError from azure.identity.aio import ( AzurePowerShellCredential, @@ -20,36 +20,42 @@ from azure.identity.aio import ( from azure.identity._constants import EnvironmentVariables import pytest -from helpers import mock_response, Request +from helpers import mock_response, Request, GET_TOKEN_METHODS from helpers_async import async_validating_transport, get_completed_future, wrap_in_future from test_shared_cache_credential import build_aad_response, get_account_event, populated_cache @pytest.mark.asyncio -async def test_iterates_only_once(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_iterates_only_once(get_token_method): """When a credential 
succeeds, DefaultAzureCredential should use that credential thereafter, ignoring the others""" unavailable_credential = Mock( - spec_set=["get_token"], get_token=Mock(side_effect=CredentialUnavailableError(message="...")) + spec_set=["get_token", "get_token_info"], + get_token=Mock(side_effect=CredentialUnavailableError(message="...")), + get_token_info=Mock(side_effect=CredentialUnavailableError(message="...")), ) successful_credential = Mock( - spec_set=["get_token"], get_token=Mock(return_value=get_completed_future(AccessToken("***", 42))) + spec_set=["get_token", "get_token_info"], + get_token=Mock(return_value=get_completed_future(AccessToken("***", 42))), + get_token_info=Mock(return_value=get_completed_future(AccessTokenInfo("***", 42))), ) credential = DefaultAzureCredential() - credential.credentials = [ + credential.credentials = ( unavailable_credential, successful_credential, Mock( - spec_set=["get_token"], + spec_set=["get_token", "get_token_info"], get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), + get_token_info=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token")), ), - ] + ) for n in range(3): - await credential.get_token("scope") - assert unavailable_credential.get_token.call_count == 1 - assert successful_credential.get_token.call_count == n + 1 + await getattr(credential, get_token_method)("scope") + assert getattr(unavailable_credential, get_token_method).call_count == 1 + assert getattr(successful_credential, get_token_method).call_count == n + 1 @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) @@ -145,7 +151,8 @@ def test_exclude_options(): @pytest.mark.asyncio -async def test_shared_cache_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_shared_cache_tenant_id(get_token_method): expected_access_token = "expected-access-token" refresh_token_a = "refresh-token-a" refresh_token_b = 
"refresh-token-b" @@ -166,14 +173,14 @@ async def test_shared_cache_tenant_id(): credential = get_credential_for_shared_cache_test( refresh_token_b, expected_access_token, cache, shared_cache_tenant_id=tenant_b ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # redundantly specifying shared_cache_username makes no difference credential = get_credential_for_shared_cache_test( refresh_token_b, expected_access_token, cache, shared_cache_tenant_id=tenant_b, shared_cache_username=upn ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # shared_cache_tenant_id should prevail over AZURE_TENANT_ID @@ -181,18 +188,19 @@ async def test_shared_cache_tenant_id(): credential = get_credential_for_shared_cache_test( refresh_token_b, expected_access_token, cache, shared_cache_tenant_id=tenant_b ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # AZURE_TENANT_ID should be used when shared_cache_tenant_id isn't specified with patch("os.environ", {EnvironmentVariables.AZURE_TENANT_ID: tenant_b}): credential = get_credential_for_shared_cache_test(refresh_token_b, expected_access_token, cache) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token @pytest.mark.asyncio -async def test_shared_cache_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_shared_cache_username(get_token_method): expected_access_token = "expected-access-token" refresh_token_a = "refresh-token-a" refresh_token_b = "refresh-token-b" @@ -212,7 +220,7 @@ async def test_shared_cache_username(): credential = get_credential_for_shared_cache_test( refresh_token_a, 
expected_access_token, cache, shared_cache_username=upn_a ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # shared_cache_username should prevail over AZURE_USERNAME @@ -220,13 +228,13 @@ async def test_shared_cache_username(): credential = get_credential_for_shared_cache_test( refresh_token_a, expected_access_token, cache, shared_cache_username=upn_a ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # AZURE_USERNAME should be used when shared_cache_username isn't specified with patch("os.environ", {EnvironmentVariables.AZURE_USERNAME: upn_b}): credential = get_credential_for_shared_cache_test(refresh_token_b, expected_access_token, cache) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token @@ -305,7 +313,7 @@ def test_process_timeout(): assert kwargs["process_timeout"] == timeout -def test_process_timeout(): +def test_process_timeout_default(): """the credential should allow configuring a process timeout for Azure CLI and PowerShell by kwarg""" with patch(DefaultAzureCredential.__module__ + ".AzureCliCredential") as mock_cli_credential: diff --git a/sdk/identity/azure-identity/tests/test_device_code_credential.py b/sdk/identity/azure-identity/tests/test_device_code_credential.py index d8ee36d4f59..c1bebe13de2 100644 --- a/sdk/identity/azure-identity/tests/test_device_code_credential.py +++ b/sdk/identity/azure-identity/tests/test_device_code_credential.py @@ -19,6 +19,7 @@ from helpers import ( mock_response, Request, validating_transport, + GET_TOKEN_METHODS, ) @@ -35,15 +36,17 @@ def test_tenant_id_validation(): DeviceCodeCredential(tenant_id=tenant) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def 
test_no_scopes(get_token_method): """The credential should raise when get_token is called with no scopes""" credential = DeviceCodeCredential("client_id") with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_authenticate(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authenticate(get_token_method): client_id = "client-id" environment = "localhost" issuer = "https://" + environment @@ -92,21 +95,23 @@ def test_authenticate(): assert record.username == username # credential should have a cached access token for the scope used in authenticate - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token -def test_disable_automatic_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_disable_automatic_authentication(get_token_method): """When configured for strict silent auth, the credential should raise when silent auth fails""" transport = Mock(send=Mock(side_effect=Exception("no request should be sent"))) credential = DeviceCodeCredential("client-id", disable_automatic_authentication=True, transport=transport) with pytest.raises(AuthenticationRequiredError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) client_id = "client-id" @@ -135,12 +140,13 @@ def test_policies_configurable(): client_id=client_id, prompt_callback=Mock(), policies=[policy], transport=transport ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert policy.on_request.called -def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_user_agent(get_token_method): client_id = "client-id" 
transport = validating_transport( requests=[Request()] * 2 + [Request(required_headers={"User-Agent": USER_AGENT})], @@ -164,10 +170,11 @@ def test_user_agent(): credential = DeviceCodeCredential(client_id=client_id, prompt_callback=Mock(), transport=transport) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_device_code_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_device_code_credential(get_token_method): client_id = "client-id" expected_token = "access-token" user_code = "user-code" @@ -210,7 +217,7 @@ def test_device_code_credential(): ) now = datetime.datetime.now(datetime.timezone.utc) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token # prompt_callback should have been called as documented @@ -226,7 +233,8 @@ def test_device_code_credential(): assert expires_on - now >= datetime.timedelta(seconds=expires_in) -def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_tenant_id(get_token_method): client_id = "client-id" expected_token = "access-token" user_code = "user-code" @@ -269,11 +277,15 @@ def test_tenant_id(): additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token -def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_timeout(get_token_method): flow = {"expires_in": 1800, "message": "foo"} with patch.object(DeviceCodeCredential, "_get_app") as get_app: msal_app = get_app() @@ -282,7 +294,7 @@ def test_timeout(): credential = DeviceCodeCredential(client_id="_", timeout=1, disable_instance_discovery=True) with pytest.raises(ClientAuthenticationError) as 
ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert "timed out" in ex.value.message.lower() msal_app.acquire_token_by_device_flow.assert_called_once_with(flow, exit_condition=ANY, claims_challenge=None) @@ -306,7 +318,8 @@ def test_client_capabilities(): assert kwargs["client_capabilities"] == ["CP1"] -def test_claims_challenge(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_claims_challenge(get_token_method): """get_token and authenticate should pass any claims challenge to MSAL token acquisition APIs""" msal_acquire_token_result = dict( @@ -328,7 +341,10 @@ def test_claims_challenge(): args, kwargs = msal_app.acquire_token_by_device_flow.call_args assert kwargs["claims_challenge"] == expected_claims - credential.get_token("scope", claims=expected_claims) + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_by_device_flow.call_count == 2 args, kwargs = msal_app.acquire_token_by_device_flow.call_args @@ -336,7 +352,11 @@ def test_claims_challenge(): msal_app.get_accounts.return_value = [{"home_account_id": credential._auth_record.home_account_id}] msal_app.acquire_token_silent_with_error.return_value = msal_acquire_token_result - credential.get_token("scope", claims=expected_claims) + + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_silent_with_error.call_count == 1 args, kwargs = msal_app.acquire_token_silent_with_error.call_args diff --git a/sdk/identity/azure-identity/tests/test_environment_credential.py b/sdk/identity/azure-identity/tests/test_environment_credential.py index 1bce68eb34b..b9e11908e63 100644 --- a/sdk/identity/azure-identity/tests/test_environment_credential.py +++ 
b/sdk/identity/azure-identity/tests/test_environment_credential.py @@ -9,7 +9,7 @@ from azure.identity import CredentialUnavailableError, EnvironmentCredential from azure.identity._constants import EnvironmentVariables import pytest -from helpers import mock +from helpers import mock, GET_TOKEN_METHODS ALL_VARIABLES = { @@ -20,17 +20,18 @@ ALL_VARIABLES = { } -def test_incomplete_configuration(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_incomplete_configuration(get_token_method): """get_token should raise CredentialUnavailableError for incomplete configuration.""" with mock.patch.dict(os.environ, {}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: - EnvironmentCredential().get_token("scope") + getattr(EnvironmentCredential(), get_token_method)("scope") for a, b in itertools.combinations(ALL_VARIABLES, 2): # all credentials require at least 3 variables set with mock.patch.dict(os.environ, {a: "a", b: "b"}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: - EnvironmentCredential().get_token("scope") + getattr(EnvironmentCredential(), get_token_method)("scope") @pytest.mark.parametrize( diff --git a/sdk/identity/azure-identity/tests/test_environment_credential_async.py b/sdk/identity/azure-identity/tests/test_environment_credential_async.py index 480cc5ac202..60dbfc07625 100644 --- a/sdk/identity/azure-identity/tests/test_environment_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_environment_credential_async.py @@ -10,7 +10,7 @@ from azure.identity.aio import EnvironmentCredential from azure.identity._constants import EnvironmentVariables import pytest -from helpers import mock_response, Request +from helpers import mock_response, Request, GET_TOKEN_METHODS from helpers_async import async_validating_transport, AsyncMockTransport from test_environment_credential import ALL_VARIABLES @@ -56,17 +56,18 @@ async def test_context_manager_incomplete_configuration(): @pytest.mark.asyncio 
-async def test_incomplete_configuration(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_incomplete_configuration(get_token_method): """get_token should raise CredentialUnavailableError for incomplete configuration.""" with mock.patch.dict(ENVIRON, {}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: - await EnvironmentCredential().get_token("scope") + await getattr(EnvironmentCredential(), get_token_method)("scope") for a, b in itertools.combinations(ALL_VARIABLES, 2): # all credentials require at least 3 variables set with mock.patch.dict(ENVIRON, {a: "a", b: "b"}, clear=True): with pytest.raises(CredentialUnavailableError) as ex: - await EnvironmentCredential().get_token("scope") + await getattr(EnvironmentCredential(), get_token_method)("scope") @pytest.mark.parametrize( @@ -169,7 +170,8 @@ def test_certificate_with_password_configuration(): @pytest.mark.asyncio -async def test_client_secret_environment_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_client_secret_environment_credential(get_token_method): client_id = "fake-client-id" secret = "fake-client-secret" tenant_id = "fake-tenant-id" @@ -195,6 +197,6 @@ async def test_client_secret_environment_credential(): EnvironmentVariables.AZURE_TENANT_ID: tenant_id, } with mock.patch.dict(ENVIRON, environment, clear=True): - token = await EnvironmentCredential(transport=transport).get_token("scope") + token = await getattr(EnvironmentCredential(transport=transport), get_token_method)("scope") assert token.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_get_token_mixin.py b/sdk/identity/azure-identity/tests/test_get_token_mixin.py index e3326b8f5cc..e0c877cb9a7 100644 --- a/sdk/identity/azure-identity/tests/test_get_token_mixin.py +++ b/sdk/identity/azure-identity/tests/test_get_token_mixin.py @@ -5,15 +5,17 @@ import time from unittest import mock -from azure.core.credentials import 
AccessToken +from azure.core.credentials import AccessTokenInfo import pytest from azure.identity._constants import DEFAULT_REFRESH_OFFSET from azure.identity._internal.get_token_mixin import GetTokenMixin +from helpers import GET_TOKEN_METHODS + class MockCredential(GetTokenMixin): - NEW_TOKEN = AccessToken("new token", 42) + NEW_TOKEN = AccessTokenInfo("new token", 42) def __init__(self, cached_token=None): super(MockCredential, self).__init__() @@ -29,85 +31,102 @@ class MockCredential(GetTokenMixin): def get_token(self, *_, **__): return super(MockCredential, self).get_token(*_, **__) + def get_token_info(self, *_, **__): + return super(MockCredential, self).get_token_info(*_, **__) + CACHED_TOKEN = "cached token" SCOPE = "scope" -def test_no_cached_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_cached_token(get_token_method): """When it has no token cached, a credential should request one every time get_token is called""" credential = MockCredential() - token = credential.get_token(SCOPE) + token = getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token -def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_tenant_id(get_token_method): credential = MockCredential() - token = credential.get_token(SCOPE, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)(SCOPE, **kwargs) assert token.token == MockCredential.NEW_TOKEN.token -def test_token_acquisition_failure(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_acquisition_failure(get_token_method): """When the 
credential has no token cached, every get_token call should prompt a token request""" credential = MockCredential() credential.request_token = mock.Mock(side_effect=Exception("whoops")) for i in range(4): with pytest.raises(Exception): - credential.get_token(SCOPE) + getattr(credential, get_token_method)(SCOPE) assert credential.request_token.call_count == i + 1 credential.request_token.assert_called_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) -def test_expired_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_expired_token(get_token_method): """A credential should request a token when it has an expired token cached""" - now = time.time() - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, now - 1)) - token = credential.get_token(SCOPE) + now = int(time.time()) + credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, now - 1)) + token = getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token -def test_cached_token_outside_refresh_window(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cached_token_outside_refresh_window(get_token_method): """A credential shouldn't request a new token when it has a cached one with sufficient validity remaining""" - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, time.time() + DEFAULT_REFRESH_OFFSET + 1)) - token = credential.get_token(SCOPE) + credential = MockCredential( + cached_token=AccessTokenInfo(CACHED_TOKEN, int(time.time() + DEFAULT_REFRESH_OFFSET + 1)) + ) + token = getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert 
credential.request_token.call_count == 0 assert token.token == CACHED_TOKEN -def test_cached_token_within_refresh_window(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cached_token_within_refresh_window(get_token_method): """A credential should request a new token when its cached one is within the refresh window""" - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, time.time() + DEFAULT_REFRESH_OFFSET - 1)) - token = credential.get_token(SCOPE) + credential = MockCredential( + cached_token=AccessTokenInfo(CACHED_TOKEN, int(time.time() + DEFAULT_REFRESH_OFFSET - 1)) + ) + token = getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token -def test_retry_delay(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_retry_delay(get_token_method): """A credential should wait between requests when trying to refresh a token""" now = time.time() - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, now + DEFAULT_REFRESH_OFFSET - 1)) + credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, int(now + DEFAULT_REFRESH_OFFSET - 1))) # the credential should swallow exceptions during proactive refresh attempts credential.request_token = mock.Mock(side_effect=Exception("whoops")) for i in range(4): - token = credential.get_token(SCOPE) + token = getattr(credential, get_token_method)(SCOPE) assert token.token == CACHED_TOKEN credential.acquire_token_silently.assert_called_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) diff --git a/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py 
b/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py index 3db42219071..562ac58383d 100644 --- a/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py +++ b/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py @@ -5,17 +5,19 @@ import time from unittest import mock -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessTokenInfo import pytest from azure.identity._constants import DEFAULT_REFRESH_OFFSET from azure.identity.aio._internal.get_token_mixin import GetTokenMixin +from helpers import GET_TOKEN_METHODS + pytestmark = pytest.mark.asyncio class MockCredential(GetTokenMixin): - NEW_TOKEN = AccessToken("new token", 42) + NEW_TOKEN = AccessTokenInfo("new token", 42) def __init__(self, cached_token=None): super(MockCredential, self).__init__() @@ -32,85 +34,102 @@ class MockCredential(GetTokenMixin): async def get_token(self, *_, **__): return await super().get_token(*_, **__) + async def get_token_info(self, *_, **__): + return await super().get_token_info(*_, **__) + CACHED_TOKEN = "cached token" SCOPE = "scope" -async def test_no_cached_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_cached_token(get_token_method): """When it has no token cached, a credential should request one every time get_token is called""" credential = MockCredential() - token = await credential.get_token(SCOPE) + token = await getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token -async def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_tenant_id(get_token_method): credential = MockCredential() - token = await credential.get_token(SCOPE, tenant_id="tenant_id") + kwargs = 
{"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)(SCOPE, **kwargs) assert token.token == MockCredential.NEW_TOKEN.token -async def test_token_acquisition_failure(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_acquisition_failure(get_token_method): """When the credential has no token cached, every get_token call should prompt a token request""" credential = MockCredential() credential.request_token = mock.Mock(side_effect=Exception("whoops")) for i in range(4): with pytest.raises(Exception): - await credential.get_token(SCOPE) + await getattr(credential, get_token_method)(SCOPE) assert credential.request_token.call_count == i + 1 credential.request_token.assert_called_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) -async def test_expired_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_expired_token(get_token_method): """A credential should request a token when it has an expired token cached""" - now = time.time() - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, now - 1)) - token = await credential.get_token(SCOPE) + now = int(time.time()) + credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, now - 1)) + token = await getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token -async def test_cached_token_outside_refresh_window(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cached_token_outside_refresh_window(get_token_method): """A credential shouldn't request a new token when it has a cached one with sufficient validity remaining""" - 
credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, time.time() + DEFAULT_REFRESH_OFFSET + 1)) - token = await credential.get_token(SCOPE) + credential = MockCredential( + cached_token=AccessTokenInfo(CACHED_TOKEN, int(time.time() + DEFAULT_REFRESH_OFFSET + 1)) + ) + token = await getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert credential.request_token.call_count == 0 assert token.token == CACHED_TOKEN -async def test_cached_token_within_refresh_window(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cached_token_within_refresh_window(get_token_method): """A credential should request a new token when its cached one is within the refresh window""" - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, time.time() + DEFAULT_REFRESH_OFFSET - 1)) - token = await credential.get_token(SCOPE) + credential = MockCredential( + cached_token=AccessTokenInfo(CACHED_TOKEN, int(time.time() + DEFAULT_REFRESH_OFFSET - 1)) + ) + token = await getattr(credential, get_token_method)(SCOPE) credential.acquire_token_silently.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) assert token.token == MockCredential.NEW_TOKEN.token -async def test_retry_delay(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_retry_delay(get_token_method): """A credential should wait between requests when trying to refresh a token""" now = time.time() - credential = MockCredential(cached_token=AccessToken(CACHED_TOKEN, now + DEFAULT_REFRESH_OFFSET - 1)) + credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, int(now + DEFAULT_REFRESH_OFFSET - 1))) # the credential should swallow exceptions during proactive refresh attempts credential.request_token = 
mock.Mock(side_effect=Exception("whoops")) for i in range(4): - token = await credential.get_token(SCOPE) + token = await getattr(credential, get_token_method)(SCOPE) assert token.token == CACHED_TOKEN credential.acquire_token_silently.assert_called_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) diff --git a/sdk/identity/azure-identity/tests/test_imds_credential.py b/sdk/identity/azure-identity/tests/test_imds_credential.py index 3bced605448..27422549276 100644 --- a/sdk/identity/azure-identity/tests/test_imds_credential.py +++ b/sdk/identity/azure-identity/tests/test_imds_credential.py @@ -2,50 +2,47 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -import json +from itertools import product import time -from devtools_testutils import recorded_by_proxy -from azure.core.credentials import AccessToken -from azure.core.exceptions import ClientAuthenticationError - from azure.identity import CredentialUnavailableError -from azure.identity._constants import EnvironmentVariables -from azure.identity._credentials.imds import IMDS_TOKEN_PATH, ImdsCredential, IMDS_AUTHORITY, PIPELINE_SETTINGS -from azure.identity._internal.user_agent import USER_AGENT +from azure.identity._credentials.imds import IMDS_TOKEN_PATH, ImdsCredential, IMDS_AUTHORITY from azure.identity._internal.utils import within_credential_chain import pytest -from helpers import mock, mock_response, Request, validating_transport +from helpers import mock, mock_response, Request, validating_transport, GET_TOKEN_METHODS from recorded_test_case import RecordedTestCase -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = ImdsCredential() with 
pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_multiple_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multiple_scopes(get_token_method): """The credential should raise ValueError when get_token is called with more than one scope""" credential = ImdsCredential() with pytest.raises(ValueError): - credential.get_token("one scope", "and another") + getattr(credential, get_token_method)("one scope", "and another") -def test_identity_not_available(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_identity_not_available(get_token_method): """The credential should raise CredentialUnavailableError when the endpoint responds 400 to a token request""" transport = validating_transport(requests=[Request()], responses=[mock_response(status_code=400, json_payload={})]) credential = ImdsCredential(transport=transport) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -@pytest.mark.parametrize("error_ending", ("network", "host", "foo")) -def test_imds_request_failure_docker_desktop(error_ending): +@pytest.mark.parametrize("error_ending,get_token_method", product(("network", "host", "foo"), GET_TOKEN_METHODS)) +def test_imds_request_failure_docker_desktop(error_ending, get_token_method): """The credential should raise CredentialUnavailableError when a 403 with a specific message is received""" error_message = ( @@ -57,47 +54,55 @@ def test_imds_request_failure_docker_desktop(error_ending): credential = ImdsCredential(transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert error_message in ex.value.message @pytest.mark.usefixtures("record_imds_test") class TestImds(RecordedTestCase): - @recorded_by_proxy - def test_system_assigned(self): + + 
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + def test_system_assigned(self, recorded_test, get_token_method): credential = ImdsCredential() - token = credential.get_token(self.scope) + token = getattr(credential, get_token_method)(self.scope) assert token.token assert isinstance(token.expires_on, int) - @recorded_by_proxy - def test_system_assigned_tenant_id(self): + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + def test_system_assigned_tenant_id(self, recorded_test, get_token_method): credential = ImdsCredential() - token = credential.get_token(self.scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)(self.scope, **kwargs) assert token.token assert isinstance(token.expires_on, int) @pytest.mark.usefixtures("user_assigned_identity_client_id") - @recorded_by_proxy - def test_user_assigned(self): + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + def test_user_assigned(self, recorded_test, get_token_method): credential = ImdsCredential(client_id=self.user_assigned_identity_client_id) - token = credential.get_token(self.scope) + token = getattr(credential, get_token_method)(self.scope) assert token.token assert isinstance(token.expires_on, int) @pytest.mark.usefixtures("user_assigned_identity_client_id") - @recorded_by_proxy - def test_user_assigned_tenant_id(self): + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + def test_user_assigned_tenant_id(self, recorded_test, get_token_method): credential = ImdsCredential(client_id=self.user_assigned_identity_client_id) - token = credential.get_token(self.scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)(self.scope, **kwargs) assert token.token assert isinstance(token.expires_on, 
int) - def test_managed_identity_aci_probe(self): + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + def test_managed_identity_aci_probe(self, get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" transport = validating_transport( requests=[ @@ -126,7 +131,7 @@ class TestImds(RecordedTestCase): ], ) within_credential_chain.set(True) - cred = ImdsCredential(transport=transport) - token = cred.get_token(scope) - assert token.token == expected_token.token + credential = ImdsCredential(transport=transport) + token = getattr(credential, get_token_method)(scope) + assert token.token == expected_token within_credential_chain.set(False) diff --git a/sdk/identity/azure-identity/tests/test_imds_credential_async.py b/sdk/identity/azure-identity/tests/test_imds_credential_async.py index b759335fbc1..67e169203a6 100644 --- a/sdk/identity/azure-identity/tests/test_imds_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_imds_credential_async.py @@ -2,12 +2,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
# ------------------------------------ +from itertools import product import json import time from unittest import mock -from devtools_testutils.aio import recorded_by_proxy_async -from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from azure.identity import CredentialUnavailableError from azure.identity._constants import EnvironmentVariables @@ -17,11 +16,10 @@ from azure.identity.aio._credentials.imds import ImdsCredential, PIPELINE_SETTIN from azure.identity._internal.utils import within_credential_chain import pytest -from helpers import mock_response, Request +from helpers import mock_response, Request, GET_TOKEN_METHODS from helpers_async import ( async_validating_transport, AsyncMockTransport, - await_test, get_completed_future, wrap_in_future, ) @@ -30,18 +28,20 @@ from recorded_test_case import RecordedTestCase pytestmark = pytest.mark.asyncio -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = ImdsCredential() with pytest.raises(ValueError): - await credential.get_token() + await getattr(credential, get_token_method)() -async def test_multiple_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multiple_scopes(get_token_method): """The credential should raise ValueError when get_token is called with more than one scope""" credential = ImdsCredential() with pytest.raises(ValueError): - await credential.get_token("one scope", "and another") + await getattr(credential, get_token_method)("one scope", "and another") async def test_imds_close(): @@ -64,7 +64,8 @@ async def test_imds_context_manager(): assert transport.__aexit__.call_count == 1 -async def test_identity_not_available(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def 
test_identity_not_available(get_token_method): """The credential should raise CredentialUnavailableError when the endpoint responds 400 to a token request""" transport = async_validating_transport( @@ -74,10 +75,11 @@ async def test_identity_not_available(): credential = ImdsCredential(transport=transport) with pytest.raises(CredentialUnavailableError): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") -async def test_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_unexpected_error(get_token_method): """The credential should raise ClientAuthenticationError when the endpoint returns an unexpected error""" error_message = "something went wrong" @@ -94,13 +96,13 @@ async def test_unexpected_error(): credential = ImdsCredential(transport=transport) with pytest.raises(ClientAuthenticationError) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert error_message in ex.value.message -@pytest.mark.parametrize("error_ending", ("network", "host", "foo")) -async def test_imds_request_failure_docker_desktop(error_ending): +@pytest.mark.parametrize("error_ending,get_token_method", product(("network", "host", "foo"), GET_TOKEN_METHODS)) +async def test_imds_request_failure_docker_desktop(error_ending, get_token_method): """The credential should raise CredentialUnavailableError when a 403 with a specific message is received""" error_message = ( @@ -112,12 +114,13 @@ async def test_imds_request_failure_docker_desktop(error_ending): credential = ImdsCredential(transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert error_message in ex.value.message -async def test_cache(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cache(get_token_method): scope = "https://foo.bar" expired = "this 
token's expired" now = int(time.time()) @@ -140,7 +143,7 @@ async def test_cache(): mock_send = mock.Mock(return_value=mock_response) credential = ImdsCredential(transport=mock.Mock(send=wrap_in_future(mock_send))) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expired assert mock_send.call_count == 1 @@ -149,17 +152,18 @@ async def test_cache(): token_payload["expires_on"] = int(time.time()) + 3600 token_payload["expires_in"] = 3600 token_payload["access_token"] = good_for_an_hour - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == good_for_an_hour assert mock_send.call_count == 2 # get_token should return the cached token now - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == good_for_an_hour assert mock_send.call_count == 2 -async def test_retries(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_retries(get_token_method): mock_response = mock.Mock( text=lambda encoding=None: b"{}", headers={"content-type": "application/json"}, @@ -173,20 +177,24 @@ async def test_retries(): mock_send.reset_mock() mock_response.status_code = status_code try: - await ImdsCredential( - transport=mock.Mock(send=wrap_in_future(mock_send), sleep=wrap_in_future(lambda _: None)) - ).get_token("scope") + await getattr( + ImdsCredential( + transport=mock.Mock(send=wrap_in_future(mock_send), sleep=wrap_in_future(lambda _: None)) + ), + get_token_method, + )("scope") except ClientAuthenticationError: pass # credential should have then exhausted retries for each of these status codes assert mock_send.call_count == 1 + total_retries -async def test_identity_config(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_identity_config(get_token_method): param_name, param_value = "foo", "bar" access_token 
= "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" client_id = "some-guid" @@ -215,12 +223,13 @@ async def test_identity_config(): ) credential = ImdsCredential(client_id=client_id, identity_config={param_name: param_value}, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) - assert token == expected_token + assert token.token == expected_token -async def test_imds_authority_override(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_imds_authority_override(get_token_method): authority = "https://localhost" expected_token = "***" scope = "scope" @@ -252,52 +261,60 @@ async def test_imds_authority_override(): with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST: authority}, clear=True): credential = ImdsCredential(transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.usefixtures("record_imds_test") class TestImdsAsync(RecordedTestCase): - @await_test - @recorded_by_proxy_async - async def test_system_assigned(self): + + @pytest.mark.asyncio + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + async def test_system_assigned(self, recorded_test, get_token_method): credential = ImdsCredential() - token = await credential.get_token(self.scope) - assert token.token - assert isinstance(token.expires_on, int) - - @await_test - @recorded_by_proxy_async - async def test_system_assigned_tenant_id(self): - credential = ImdsCredential() - token = await credential.get_token(self.scope, tenant_id="tenant_id") - assert token.token - assert isinstance(token.expires_on, int) - - @pytest.mark.usefixtures("user_assigned_identity_client_id") - @await_test - @recorded_by_proxy_async - async def test_user_assigned(self): - 
credential = ImdsCredential(client_id=self.user_assigned_identity_client_id) - token = await credential.get_token(self.scope) - assert token.token - assert isinstance(token.expires_on, int) - - @pytest.mark.usefixtures("user_assigned_identity_client_id") - @await_test - @recorded_by_proxy_async - async def test_user_assigned_tenant_id(self): - credential = ImdsCredential(client_id=self.user_assigned_identity_client_id) - token = await credential.get_token(self.scope, tenant_id="tenant_id") + token = await getattr(credential, get_token_method)(self.scope) assert token.token assert isinstance(token.expires_on, int) @pytest.mark.asyncio - async def test_managed_identity_aci_probe(self): + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + async def test_system_assigned_tenant_id(self, recorded_test, get_token_method): + credential = ImdsCredential() + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)(self.scope, **kwargs) + assert token.token + assert isinstance(token.expires_on, int) + + @pytest.mark.usefixtures("user_assigned_identity_client_id") + @pytest.mark.asyncio + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + async def test_user_assigned(self, recorded_test, get_token_method): + credential = ImdsCredential(client_id=self.user_assigned_identity_client_id) + token = await getattr(credential, get_token_method)(self.scope) + assert token.token + assert isinstance(token.expires_on, int) + + @pytest.mark.usefixtures("user_assigned_identity_client_id") + @pytest.mark.asyncio + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + async def test_user_assigned_tenant_id(self, recorded_test, get_token_method): + credential = ImdsCredential(client_id=self.user_assigned_identity_client_id) + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = 
await getattr(credential, get_token_method)(self.scope, **kwargs) + assert token.token + assert isinstance(token.expires_on, int) + + @pytest.mark.asyncio + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + async def test_managed_identity_aci_probe(self, get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" transport = async_validating_transport( requests=[ @@ -325,7 +342,7 @@ class TestImdsAsync(RecordedTestCase): ], ) within_credential_chain.set(True) - cred = ImdsCredential(transport=transport) - token = await cred.get_token(scope) - assert token == expected_token + credential = ImdsCredential(transport=transport) + token = await getattr(credential, get_token_method)(scope) + assert token.token == expected_token within_credential_chain.set(False) diff --git a/sdk/identity/azure-identity/tests/test_initialization.py b/sdk/identity/azure-identity/tests/test_initialization.py new file mode 100644 index 00000000000..e44fe487a3d --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_initialization.py @@ -0,0 +1,74 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# ------------------------------------ +import sys + +from azure.core.credentials import SupportsTokenInfo, TokenCredential +from azure.identity import ( + AuthorizationCodeCredential, + CertificateCredential, + ClientSecretCredential, + DeviceCodeCredential, + EnvironmentCredential, + InteractiveBrowserCredential, + ManagedIdentityCredential, + OnBehalfOfCredential, + SharedTokenCacheCredential, + UsernamePasswordCredential, + VisualStudioCodeCredential, + WorkloadIdentityCredential, + DefaultAzureCredential, + ChainedTokenCredential, + AzureCliCredential, + AzurePowerShellCredential, + AzureDeveloperCliCredential, + AzurePipelinesCredential, +) +import pytest + + +def test_credential_is_token_credential(): + assert isinstance(AuthorizationCodeCredential, TokenCredential) + assert isinstance(CertificateCredential, TokenCredential) + assert isinstance(ClientSecretCredential, TokenCredential) + assert isinstance(DeviceCodeCredential, TokenCredential) + assert isinstance(EnvironmentCredential, TokenCredential) + assert isinstance(InteractiveBrowserCredential, TokenCredential) + assert isinstance(ManagedIdentityCredential, TokenCredential) + assert isinstance(OnBehalfOfCredential, TokenCredential) + assert isinstance(SharedTokenCacheCredential, TokenCredential) + assert isinstance(UsernamePasswordCredential, TokenCredential) + assert isinstance(VisualStudioCodeCredential, TokenCredential) + assert isinstance(WorkloadIdentityCredential, TokenCredential) + assert isinstance(DefaultAzureCredential, TokenCredential) + assert isinstance(ChainedTokenCredential, TokenCredential) + assert isinstance(AzureCliCredential, TokenCredential) + assert isinstance(AzurePowerShellCredential, TokenCredential) + assert isinstance(AzureDeveloperCliCredential, TokenCredential) + assert isinstance(AzurePipelinesCredential, TokenCredential) + + +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason="isinstance check doesn't seem to work when the Protocol subclasses ContextManager in 
Python <=3.8", +) +def test_credential_is_supports_token_info(): + assert isinstance(AuthorizationCodeCredential, SupportsTokenInfo) + assert isinstance(CertificateCredential, SupportsTokenInfo) + assert isinstance(ClientSecretCredential, SupportsTokenInfo) + assert isinstance(DeviceCodeCredential, SupportsTokenInfo) + assert isinstance(EnvironmentCredential, SupportsTokenInfo) + assert isinstance(InteractiveBrowserCredential, SupportsTokenInfo) + assert isinstance(ManagedIdentityCredential, SupportsTokenInfo) + assert isinstance(OnBehalfOfCredential, SupportsTokenInfo) + assert isinstance(SharedTokenCacheCredential, SupportsTokenInfo) + assert isinstance(UsernamePasswordCredential, SupportsTokenInfo) + assert isinstance(VisualStudioCodeCredential, SupportsTokenInfo) + assert isinstance(WorkloadIdentityCredential, SupportsTokenInfo) + assert isinstance(DefaultAzureCredential, SupportsTokenInfo) + assert isinstance(ChainedTokenCredential, SupportsTokenInfo) + assert isinstance(AzureCliCredential, SupportsTokenInfo) + assert isinstance(AzurePowerShellCredential, SupportsTokenInfo) + assert isinstance(AzureDeveloperCliCredential, SupportsTokenInfo) + assert isinstance(AzurePipelinesCredential, SupportsTokenInfo) diff --git a/sdk/identity/azure-identity/tests/test_initialization_async.py b/sdk/identity/azure-identity/tests/test_initialization_async.py new file mode 100644 index 00000000000..56a5f917574 --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_initialization_async.py @@ -0,0 +1,69 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# ------------------------------------ +import sys + +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential +from azure.identity.aio import ( + AuthorizationCodeCredential, + CertificateCredential, + ClientSecretCredential, + EnvironmentCredential, + ManagedIdentityCredential, + OnBehalfOfCredential, + SharedTokenCacheCredential, + VisualStudioCodeCredential, + WorkloadIdentityCredential, + DefaultAzureCredential, + ChainedTokenCredential, + AzureCliCredential, + AzurePowerShellCredential, + AzureDeveloperCliCredential, + AzurePipelinesCredential, +) +import pytest + + +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason="isinstance check doesn't seem to work when the Protocol subclasses AsyncContextManager in Python <=3.8", +) +def test_credential_is_async_token_credential(): + assert isinstance(AuthorizationCodeCredential, AsyncTokenCredential) + assert isinstance(CertificateCredential, AsyncTokenCredential) + assert isinstance(ClientSecretCredential, AsyncTokenCredential) + assert isinstance(EnvironmentCredential, AsyncTokenCredential) + assert isinstance(ManagedIdentityCredential, AsyncTokenCredential) + assert isinstance(OnBehalfOfCredential, AsyncTokenCredential) + assert isinstance(SharedTokenCacheCredential, AsyncTokenCredential) + assert isinstance(VisualStudioCodeCredential, AsyncTokenCredential) + assert isinstance(WorkloadIdentityCredential, AsyncTokenCredential) + assert isinstance(DefaultAzureCredential, AsyncTokenCredential) + assert isinstance(ChainedTokenCredential, AsyncTokenCredential) + assert isinstance(AzureCliCredential, AsyncTokenCredential) + assert isinstance(AzurePowerShellCredential, AsyncTokenCredential) + assert isinstance(AzureDeveloperCliCredential, AsyncTokenCredential) + assert isinstance(AzurePipelinesCredential, AsyncTokenCredential) + + +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason="isinstance check doesn't seem to work when the Protocol subclasses AsyncContextManager in 
Python <=3.8", +) +def test_credential_is_async_supports_token_info(): + assert isinstance(AuthorizationCodeCredential, AsyncSupportsTokenInfo) + assert isinstance(CertificateCredential, AsyncSupportsTokenInfo) + assert isinstance(ClientSecretCredential, AsyncSupportsTokenInfo) + assert isinstance(EnvironmentCredential, AsyncSupportsTokenInfo) + assert isinstance(ManagedIdentityCredential, AsyncSupportsTokenInfo) + assert isinstance(OnBehalfOfCredential, AsyncSupportsTokenInfo) + assert isinstance(SharedTokenCacheCredential, AsyncSupportsTokenInfo) + assert isinstance(VisualStudioCodeCredential, AsyncSupportsTokenInfo) + assert isinstance(WorkloadIdentityCredential, AsyncSupportsTokenInfo) + assert isinstance(DefaultAzureCredential, AsyncSupportsTokenInfo) + assert isinstance(ChainedTokenCredential, AsyncSupportsTokenInfo) + assert isinstance(AzureCliCredential, AsyncSupportsTokenInfo) + assert isinstance(AzurePowerShellCredential, AsyncSupportsTokenInfo) + assert isinstance(AzureDeveloperCliCredential, AsyncSupportsTokenInfo) + assert isinstance(AzurePipelinesCredential, AsyncSupportsTokenInfo) diff --git a/sdk/identity/azure-identity/tests/test_interactive_credential.py b/sdk/identity/azure-identity/tests/test_interactive_credential.py index fc1ec6ad855..d623f8b8004 100644 --- a/sdk/identity/azure-identity/tests/test_interactive_credential.py +++ b/sdk/identity/azure-identity/tests/test_interactive_credential.py @@ -17,7 +17,7 @@ import pytest from urllib.parse import urlparse from unittest.mock import Mock, patch -from helpers import build_aad_response, get_discovery_response, id_token_claims +from helpers import build_aad_response, get_discovery_response, id_token_claims, GET_TOKEN_METHODS # fake object for tests which need to exercise request_token but don't care about its return value @@ -49,15 +49,17 @@ class MockCredential(InteractiveCredential): return self._request_token_impl(*scopes, **kwargs) -def test_no_scopes(): 
+@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise when get_token is called with no scopes""" request_token = Mock(side_effect=Exception("credential shouldn't begin interactive authentication")) with pytest.raises(ValueError): - MockCredential(request_token=request_token).get_token() + getattr(MockCredential(request_token=request_token), get_token_method)() -def test_authentication_record_argument(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authentication_record_argument(get_token_method): """The credential should initialize its msal.ClientApplication with values from a given record""" record = AuthenticationRecord("tenant-id", "client-id", "localhost", "object.tenant", "username") @@ -72,12 +74,13 @@ def test_authentication_record_argument(): credential = MockCredential(authentication_record=record, disable_automatic_authentication=True) with pytest.raises(AuthenticationRequiredError): with patch("msal.PublicClientApplication", mock_client_application): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_client_application.call_count == 1, "credential didn't create an msal application" -def test_enable_support_logging(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_enable_support_logging(get_token_method): """The keyword argument for enabling PII in MSAL should be passed.""" record = AuthenticationRecord("tenant-id", "client-id", "localhost", "object.tenant", "username") @@ -95,14 +98,15 @@ def test_enable_support_logging(): ) with pytest.raises(AuthenticationRequiredError): with patch("msal.PublicClientApplication", mock_client_application): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_client_application.call_count == 1, "credential didn't create an msal application" _, kwargs = mock_client_application.call_args assert 
kwargs["enable_pii_log"] -def test_tenant_argument_overrides_record(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_tenant_argument_overrides_record(get_token_method): """The 'tenant_ic' keyword argument should override a given record's value""" tenant_id = "some-guid" @@ -121,10 +125,11 @@ def test_tenant_argument_overrides_record(): ) with pytest.raises(AuthenticationRequiredError): with patch("msal.PublicClientApplication", validate_authority): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_disable_automatic_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_disable_automatic_authentication(get_token_method): """When silent auth fails the credential should raise, if it's configured not to authenticate automatically""" expected_details = "something went wrong" @@ -144,14 +149,18 @@ def test_disable_automatic_authentication(): expected_claims = "..." with pytest.raises(AuthenticationRequiredError) as ex: with patch("msal.PublicClientApplication", lambda *_, **__: msal_app): - credential.get_token(scope, claims=expected_claims) + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)(scope, **kwargs) # the exception should carry the requested scopes and claims, and any error message from Microsoft Entra ID assert ex.value.scopes == (scope,) assert ex.value.claims == expected_claims -def test_scopes_round_trip(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_scopes_round_trip(get_token_method): """authenticate should accept the value of AuthenticationRequiredError.scopes""" scope = "scope" @@ -163,7 +172,7 @@ def test_scopes_round_trip(): request_token = Mock(wraps=validate_scopes) credential = MockCredential(disable_automatic_authentication=True, request_token=request_token) with pytest.raises(AuthenticationRequiredError) as ex: 
- credential.get_token(scope) + getattr(credential, get_token_method)(scope) credential.authenticate(scopes=ex.value.scopes) @@ -191,7 +200,8 @@ def test_authenticate_default_scopes(authority, expected_scope): assert request_token.call_count == 1 -def test_authenticate_unknown_cloud(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authenticate_unknown_cloud(get_token_method): """authenticate should raise when given no scopes in an unknown cloud""" with pytest.raises(CredentialUnavailableError): @@ -207,7 +217,8 @@ def test_authenticate_ignores_disable_automatic_authentication(option): assert request_token.call_count == 1, "credential didn't begin interactive authentication" -def test_get_token_wraps_exceptions(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_get_token_wraps_exceptions(get_token_method): """get_token shouldn't propagate exceptions from MSAL""" class CustomException(Exception): @@ -222,13 +233,14 @@ def test_get_token_wraps_exceptions(): credential = MockCredential(authentication_record=record) with pytest.raises(ClientAuthenticationError) as ex: with patch("msal.PublicClientApplication", lambda *_, **__: msal_app): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert expected_message in ex.value.message assert msal_app.acquire_token_silent_with_error.call_count == 1, "credential didn't attempt silent auth" -def test_token_cache_persistent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_cache_persistent(get_token_method): """the credential should default to an in memory cache, and optionally use a persistent cache""" class TestCredential(InteractiveCredential): @@ -255,12 +267,15 @@ def test_token_cache_persistent(): ) assert not load_persistent_cache.called - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert load_persistent_cache.call_count == 1 assert credential._cache is not None 
assert credential._cae_cache is None - credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert load_persistent_cache.call_count == 2 assert credential._cae_cache is not None @@ -268,15 +283,16 @@ def test_token_cache_persistent(): assert credential2._cache is None assert credential2._cae_cache is None - credential2.get_token("scope") + getattr(credential2, get_token_method)("scope") assert isinstance(credential2._cache, TokenCache) assert credential2._cae_cache is None - credential2.get_token("scope", enable_cae=True) + getattr(credential2, get_token_method)("scope", **kwargs) assert isinstance(credential2._cae_cache, TokenCache) -def test_home_account_id_client_info(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_home_account_id_client_info(get_token_method): """when MSAL returns client_info, the credential should decode it to get the home_account_id""" object_id = "object-id" @@ -302,7 +318,8 @@ def test_home_account_id_client_info(): assert record.home_account_id == "{}.{}".format(object_id, home_tenant) -def test_adfs(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_adfs(get_token_method): """the credential should be able to construct an AuthenticationRecord from an ADFS response returned by MSAL""" authority = "localhost" @@ -333,7 +350,8 @@ def test_adfs(): assert record.username == username -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -368,21 +386,28 @@ def test_multitenant_authentication(): transport=Mock(send=send), additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope") + token = getattr(credential, 
get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_not_allowed(get_token_method): expected_tenant = "expected-tenant" expected_token = "***" @@ -410,12 +435,34 @@ def test_multitenant_authentication_not_allowed(): credential = MockCredential(tenant_id=expected_tenant, transport=Mock(send=send), request_token=request_token) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token - token = credential.get_token("scope", tenant_id=expected_tenant) + kwargs = {"tenant_id": expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = credential.get_token("scope", tenant_id="un" + expected_tenant) + kwargs = {"tenant_id": "un" + expected_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", 
**kwargs) assert token.token == expected_token + + +def test_arbitrary_kwargs_propagated_get_token_info(): + """For intermediary testing of PoP support.""" + + class TestCredential(InteractiveCredential): + def __init__(self, **kwargs): + super(TestCredential, self).__init__(client_id="...", **kwargs) + + def _request_token(self, *_, **kwargs): + assert "foo" in kwargs + raise ValueError("Raising here since keyword arg was propagated") + + credential = TestCredential() + with pytest.raises(ValueError): + credential.get_token_info("scope", options={"foo": "bar"}) # type: ignore diff --git a/sdk/identity/azure-identity/tests/test_live.py b/sdk/identity/azure-identity/tests/test_live.py index f57244dd843..f475404a5dc 100644 --- a/sdk/identity/azure-identity/tests/test_live.py +++ b/sdk/identity/azure-identity/tests/test_live.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from itertools import product import pytest from azure.identity import ( @@ -16,31 +17,33 @@ from azure.identity import ( ) from azure.identity._constants import DEVELOPER_SIGN_ON_CLIENT_ID -from helpers import get_token_payload_contents +from helpers import get_token_payload_contents, GET_TOKEN_METHODS ARM_SCOPE = "https://management.azure.com/.default" -def get_token(credential, **kwargs): - token = credential.get_token(ARM_SCOPE, **kwargs) +def get_token(credential, method, **kwargs): + token = getattr(credential, method)(ARM_SCOPE, **kwargs) assert token assert token.token assert token.expires_on return token -@pytest.mark.parametrize("certificate_fixture", ("live_pem_certificate", "live_pfx_certificate")) -def test_certificate_credential(certificate_fixture, request): +@pytest.mark.parametrize( + "certificate_fixture,get_token_method", product(("live_pem_certificate", "live_pfx_certificate"), GET_TOKEN_METHODS) +) +def test_certificate_credential(certificate_fixture, get_token_method, request): cert = 
request.getfixturevalue(certificate_fixture) tenant_id = cert["tenant_id"] client_id = cert["client_id"] credential = CertificateCredential(tenant_id, client_id, cert["cert_path"]) - get_token(credential) + get_token(credential, get_token_method) credential = CertificateCredential(tenant_id, client_id, certificate_data=cert["cert_bytes"]) - token = get_token(credential, enable_cae=True) + token = get_token(credential, get_token_method, enable_cae=True) parsed_payload = get_token_payload_contents(token.token) assert "xms_cc" in parsed_payload and "CP1" in parsed_payload["xms_cc"] @@ -48,61 +51,68 @@ def test_certificate_credential(certificate_fixture, request): credential = CertificateCredential( tenant_id, client_id, cert["cert_with_password_path"], password=cert["password"] ) - get_token(credential) + get_token(credential, get_token_method) credential = CertificateCredential( tenant_id, client_id, certificate_data=cert["cert_with_password_bytes"], password=cert["password"] ) - get_token(credential) + get_token(credential, get_token_method) -def test_client_secret_credential(live_service_principal): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_client_secret_credential(live_service_principal, get_token_method): credential = ClientSecretCredential( live_service_principal["tenant_id"], live_service_principal["client_id"], live_service_principal["client_secret"], ) - token = get_token(credential, enable_cae=True) + token = get_token(credential, get_token_method, enable_cae=True) parsed_payload = get_token_payload_contents(token.token) assert "xms_cc" in parsed_payload and "CP1" in parsed_payload["xms_cc"] -def test_default_credential(live_service_principal): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_default_credential(live_service_principal, get_token_method): credential = DefaultAzureCredential() - get_token(credential) + get_token(credential, get_token_method) -def 
test_username_password_auth(live_user_details): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_username_password_auth(live_user_details, get_token_method): credential = UsernamePasswordCredential( client_id=live_user_details["client_id"], username=live_user_details["username"], password=live_user_details["password"], tenant_id=live_user_details["tenant"], ) - get_token(credential) + get_token(credential, get_token_method) @pytest.mark.manual -def test_cli_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cli_credential(get_token_method): credential = AzureCliCredential() - get_token(credential) + get_token(credential, get_token_method) @pytest.mark.manual -def test_dev_cli_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_dev_cli_credential(get_token_method): credential = AzureDeveloperCliCredential() - get_token(credential) + get_token(credential, get_token_method) @pytest.mark.manual -def test_powershell_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_powershell_credential(get_token_method): credential = AzurePowerShellCredential() - get_token(credential) + get_token(credential, get_token_method) @pytest.mark.manual @pytest.mark.prints -def test_device_code(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_device_code(get_token_method): import webbrowser def prompt(url, user_code, _): @@ -110,4 +120,4 @@ def test_device_code(): webbrowser.open_new_tab(url) credential = DeviceCodeCredential(client_id=DEVELOPER_SIGN_ON_CLIENT_ID, prompt_callback=prompt, timeout=40) - get_token(credential) + get_token(credential, get_token_method) diff --git a/sdk/identity/azure-identity/tests/test_live_async.py b/sdk/identity/azure-identity/tests/test_live_async.py index 073ef3f65e7..8c8ed19f64f 100644 --- a/sdk/identity/azure-identity/tests/test_live_async.py +++ 
b/sdk/identity/azure-identity/tests/test_live_async.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from itertools import product import pytest from azure.identity.aio import ( @@ -13,13 +14,13 @@ from azure.identity.aio import ( AzureDeveloperCliCredential, ) -from helpers import get_token_payload_contents +from helpers import get_token_payload_contents, GET_TOKEN_METHODS ARM_SCOPE = "https://management.azure.com/.default" -async def get_token(credential, **kwargs): - token = await credential.get_token(ARM_SCOPE, **kwargs) +async def get_token(credential, get_token_method, **kwargs): + token = await getattr(credential, get_token_method)(ARM_SCOPE, **kwargs) assert token assert token.token assert token.expires_on @@ -27,18 +28,20 @@ async def get_token(credential, **kwargs): @pytest.mark.asyncio -@pytest.mark.parametrize("certificate_fixture", ("live_pem_certificate", "live_pfx_certificate")) -async def test_certificate_credential(certificate_fixture, request): +@pytest.mark.parametrize( + "certificate_fixture,get_token_method", product(("live_pem_certificate", "live_pfx_certificate"), GET_TOKEN_METHODS) +) +async def test_certificate_credential(certificate_fixture, get_token_method, request): cert = request.getfixturevalue(certificate_fixture) tenant_id = cert["tenant_id"] client_id = cert["client_id"] credential = CertificateCredential(tenant_id, client_id, cert["cert_path"]) - await get_token(credential) + await get_token(credential, get_token_method) credential = CertificateCredential(tenant_id, client_id, certificate_data=cert["cert_bytes"]) - token = await get_token(credential, enable_cae=True) + token = await get_token(credential, get_token_method, enable_cae=True) parsed_payload = get_token_payload_contents(token.token) assert "xms_cc" in parsed_payload and "CP1" in parsed_payload["xms_cc"] @@ -46,48 +49,53 @@ async def test_certificate_credential(certificate_fixture, 
request): credential = CertificateCredential( tenant_id, client_id, cert["cert_with_password_path"], password=cert["password"] ) - await get_token(credential) + await get_token(credential, get_token_method) credential = CertificateCredential( tenant_id, client_id, certificate_data=cert["cert_with_password_bytes"], password=cert["password"] ) - await get_token(credential, enable_cae=True) + await get_token(credential, get_token_method, enable_cae=True) @pytest.mark.asyncio -async def test_client_secret_credential(live_service_principal): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_client_secret_credential(live_service_principal, get_token_method): credential = ClientSecretCredential( live_service_principal["tenant_id"], live_service_principal["client_id"], live_service_principal["client_secret"], ) - token = await get_token(credential, enable_cae=True) + token = await get_token(credential, get_token_method, enable_cae=True) parsed_payload = get_token_payload_contents(token.token) assert "xms_cc" in parsed_payload and "CP1" in parsed_payload["xms_cc"] @pytest.mark.asyncio -async def test_default_credential(live_service_principal): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_default_credential(live_service_principal, get_token_method): credential = DefaultAzureCredential() - await get_token(credential) + await get_token(credential, get_token_method) @pytest.mark.manual @pytest.mark.asyncio -async def test_cli_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cli_credential(get_token_method): credential = AzureCliCredential() - await get_token(credential) + await get_token(credential, get_token_method) @pytest.mark.manual @pytest.mark.asyncio -async def test_dev_cli_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_dev_cli_credential(get_token_method): credential = AzureDeveloperCliCredential() - await 
get_token(credential) + await get_token(credential, get_token_method) @pytest.mark.manual @pytest.mark.asyncio -async def test_powershell_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_powershell_credential(get_token_method): credential = AzurePowerShellCredential() - await get_token(credential) + await get_token(credential, get_token_method) diff --git a/sdk/identity/azure-identity/tests/test_managed_identity.py b/sdk/identity/azure-identity/tests/test_managed_identity.py index 54ff51802e7..e14dfe724b7 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity.py @@ -2,13 +2,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -import os -import sys +from itertools import product import time from unittest import mock -from azure.core.credentials import AccessToken -from azure.core.exceptions import ClientAuthenticationError from azure.identity import ManagedIdentityCredential, CredentialUnavailableError from azure.identity._constants import EnvironmentVariables from azure.identity._credentials.imds import IMDS_AUTHORITY, IMDS_TOKEN_PATH @@ -16,7 +13,7 @@ from azure.identity._internal.user_agent import USER_AGENT from azure.identity._internal import within_credential_chain import pytest -from helpers import build_aad_response, validating_transport, mock_response, Request +from helpers import build_aad_response, validating_transport, mock_response, Request, GET_TOKEN_METHODS MANAGED_IDENTITY_ENVIRON = "azure.identity._credentials.managed_identity.os.environ" ALL_ENVIRONMENTS = ( @@ -73,8 +70,8 @@ def test_context_manager_incomplete_configuration(): pass -@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) -def test_custom_hooks(environ): +@pytest.mark.parametrize("environ,get_token_method", product(ALL_ENVIRONMENTS, GET_TOKEN_METHODS)) +def test_custom_hooks(environ, 
get_token_method): """The credential's pipeline should include azure-core's CustomHookPolicy""" scope = "scope" @@ -99,7 +96,7 @@ def test_custom_hooks(environ): credential = ManagedIdentityCredential( transport=transport, raw_request_hook=request_hook, raw_response_hook=response_hook ) - credential.get_token(scope) + getattr(credential, get_token_method)(scope) assert request_hook.call_count == 1 assert response_hook.call_count == 1 @@ -108,8 +105,8 @@ def test_custom_hooks(environ): assert pipeline_response.http_response == expected_response -@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) -def test_tenant_id(environ): +@pytest.mark.parametrize("environ,get_token_method", product(ALL_ENVIRONMENTS, GET_TOKEN_METHODS)) +def test_tenant_id(environ, get_token_method): scope = "scope" expected_token = "***" request_hook = mock.Mock() @@ -132,7 +129,7 @@ def test_tenant_id(environ): credential = ManagedIdentityCredential( transport=transport, raw_request_hook=request_hook, raw_response_hook=response_hook ) - credential.get_token(scope) + getattr(credential, get_token_method)(scope) assert request_hook.call_count == 1 assert response_hook.call_count == 1 @@ -141,12 +138,13 @@ def test_tenant_id(environ): assert pipeline_response.http_response == expected_response -def test_cloud_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cloud_shell(get_token_method): """Cloud Shell environment: only MSI_ENDPOINT set""" access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token endpoint = "http://localhost:42/token" scope = "scope" transport = validating_transport( @@ -173,14 +171,16 @@ def test_cloud_shell(): ) with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: endpoint}): - token = ManagedIdentityCredential(transport=transport).get_token(scope) - assert token == expected_token + token = getattr(ManagedIdentityCredential(transport=transport), 
get_token_method)(scope) + assert token.token == expected_token + assert token.expires_on == expires_on -def test_cloud_shell_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cloud_shell_tenant_id(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token endpoint = "http://localhost:42/token" scope = "scope" transport = validating_transport( @@ -207,14 +207,20 @@ def test_cloud_shell_tenant_id(): ) with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: endpoint}): - token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") - assert token == expected_token + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope, **kwargs) + assert token.token == expected_token + assert token.expires_on == expires_on -def test_azure_ml(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_azure_ml(get_token_method): """Azure ML: MSI_ENDPOINT, MSI_SECRET set (like App Service 2017-09-01 but with a different response format)""" - expected_token = AccessToken("****", int(time.time()) + 3600) + expected_token = "****" + expires_on = int(time.time()) + 3600 url = "http://localhost:42/token" secret = "expected-secret" scope = "scope" @@ -238,9 +244,9 @@ def test_azure_ml(): responses=[ mock_response( json_payload={ - "access_token": expected_token.token, + "access_token": expected_token, "expires_in": 3600, - "expires_on": expected_token.expires_on, + "expires_on": expires_on, "resource": scope, "token_type": "Bearer", } @@ -254,17 +260,19 @@ def test_azure_ml(): {EnvironmentVariables.MSI_ENDPOINT: url, EnvironmentVariables.MSI_SECRET: secret}, clear=True, ): - token = ManagedIdentityCredential(transport=transport).get_token(scope) - assert 
token.token == expected_token.token - assert token.expires_on == expected_token.expires_on + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) + assert token.token == expected_token + assert token.expires_on == expires_on - token = ManagedIdentityCredential(transport=transport, client_id=client_id).get_token(scope) - assert token.token == expected_token.token - assert token.expires_on == expected_token.expires_on + token = getattr(ManagedIdentityCredential(transport=transport, client_id=client_id), get_token_method)(scope) + assert token.token == expected_token + assert token.expires_on == expires_on -def test_azure_ml_tenant_id(): - expected_token = AccessToken("****", int(time.time()) + 3600) +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_azure_ml_tenant_id(get_token_method): + expected_token = "****" + expires_on = int(time.time()) + 3600 url = "http://localhost:42/token" secret = "expected-secret" scope = "scope" @@ -288,9 +296,9 @@ def test_azure_ml_tenant_id(): responses=[ mock_response( json_payload={ - "access_token": expected_token.token, + "access_token": expected_token, "expires_in": 3600, - "expires_on": expected_token.expires_on, + "expires_on": expires_on, "resource": scope, "token_type": "Bearer", } @@ -304,12 +312,16 @@ def test_azure_ml_tenant_id(): {EnvironmentVariables.MSI_ENDPOINT: url, EnvironmentVariables.MSI_SECRET: secret}, clear=True, ): - token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") - assert token.token == expected_token.token - assert token.expires_on == expected_token.expires_on + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope, **kwargs) + assert token.token == expected_token + assert token.expires_on == expires_on -def test_cloud_shell_identity_config(): 
+@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cloud_shell_identity_config(get_token_method): """Cloud Shell environment: only MSI_ENDPOINT set""" expected_token = "****" @@ -349,17 +361,18 @@ def test_cloud_shell_identity_config(): ) with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: endpoint}, clear=True): - token = ManagedIdentityCredential(transport=transport).get_token(scope) + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on credential = ManagedIdentityCredential(transport=transport, identity_config={param_name: param_value}) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on -def test_prefers_app_service_2019_08_01(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_prefers_app_service_2019_08_01(get_token_method): """When the environment is configured for both App Service versions, the credential should prefer the most recent""" access_token = "****" @@ -395,12 +408,13 @@ def test_prefers_app_service_2019_08_01(): EnvironmentVariables.MSI_SECRET: secret, } with mock.patch.dict("os.environ", environ, clear=True): - token = ManagedIdentityCredential(transport=transport).get_token(scope) + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) assert token.token == access_token assert token.expires_on == expires_on -def test_app_service_2019_08_01(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_app_service_2019_08_01(get_token_method): """App Service 2019-08-01: IDENTITY_ENDPOINT, IDENTITY_HEADER set""" access_token = "****" @@ -442,12 +456,13 @@ def test_app_service_2019_08_01(): }, clear=True, ): - token = ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope) + 
token = getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)(scope) assert token.token == access_token assert token.expires_on == expires_on -def test_app_service_2019_08_01_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_app_service_2019_08_01_tenant_id(get_token_method): """App Service 2019-08-01: IDENTITY_ENDPOINT, IDENTITY_HEADER set""" access_token = "****" @@ -489,12 +504,16 @@ def test_app_service_2019_08_01_tenant_id(): }, clear=True, ): - token = ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)(scope, **kwargs) assert token.token == access_token assert token.expires_on == expires_on -def test_app_service_user_assigned_identity(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_app_service_user_assigned_identity(get_token_method): """App Service 2019-08-01: IDENTITY_ENDPOINT, IDENTITY_HEADER set""" expected_token = "****" @@ -541,20 +560,21 @@ def test_app_service_user_assigned_identity(): {EnvironmentVariables.IDENTITY_ENDPOINT: endpoint, EnvironmentVariables.IDENTITY_HEADER: secret}, clear=True, ): - token = ManagedIdentityCredential(client_id=client_id, transport=transport).get_token(scope) + token = getattr(ManagedIdentityCredential(client_id=client_id, transport=transport), get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on credential = ManagedIdentityCredential(client_id=client_id, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on -def test_imds(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def 
test_imds(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" transport = validating_transport( requests=[ @@ -582,14 +602,15 @@ def test_imds(): # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential with mock.patch.dict("os.environ", clear=True): - token = ManagedIdentityCredential(transport=transport).get_token(scope) - assert token.token == expected_token.token + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) + assert token.token == expected_token -def test_imds_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_imds_tenant_id(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" transport = validating_transport( requests=[ @@ -617,11 +638,15 @@ def test_imds_tenant_id(): # ensure e.g. 
$MSI_ENDPOINT isn't set, so we get ImdsCredential with mock.patch.dict("os.environ", clear=True): - token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") - assert token.token == expected_token.token + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope, **kwargs) + assert token.token == expected_token -def test_imds_text_response(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_imds_text_response(get_token_method): within_credential_chain.set(True) response = mock.Mock( text=lambda encoding=None: b"{This is a text response}", @@ -632,11 +657,12 @@ def test_imds_text_response(): mock_send = mock.Mock(return_value=response) credential = ManagedIdentityCredential(transport=mock.Mock(send=mock_send)) with pytest.raises(CredentialUnavailableError): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") within_credential_chain.set(False) -def test_client_id_none(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_client_id_none(get_token_method): """the credential should ignore client_id=None""" expected_access_token = "****" @@ -655,7 +681,7 @@ def test_client_id_none(): # IMDS credential = ManagedIdentityCredential(client_id=None, transport=mock.Mock(send=send)) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_access_token # Cloud Shell @@ -663,14 +689,15 @@ def test_client_id_none(): MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}, clear=True ): credential = ManagedIdentityCredential(client_id=None, transport=mock.Mock(send=send)) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_access_token -def 
test_imds_user_assigned_identity(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_imds_user_assigned_identity(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token endpoint = IMDS_AUTHORITY + IMDS_TOKEN_PATH scope = "scope" client_id = "some-guid" @@ -701,11 +728,12 @@ def test_imds_user_assigned_identity(): # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential with mock.patch.dict("os.environ", clear=True): - token = ManagedIdentityCredential(client_id=client_id, transport=transport).get_token(scope) - assert token.token == expected_token.token + token = getattr(ManagedIdentityCredential(client_id=client_id, transport=transport), get_token_method)(scope) + assert token.token == expected_token -def test_service_fabric(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_service_fabric(get_token_method): """Service Fabric 2019-07-01-preview""" access_token = "****" expires_on = 42 @@ -741,12 +769,13 @@ def test_service_fabric(): EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: thumbprint, }, ): - token = ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope) + token = getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)(scope) assert token.token == access_token assert token.expires_on == expires_on -def test_service_fabric_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_service_fabric_tenant_id(get_token_method): access_token = "****" expires_on = 42 endpoint = "http://localhost:42/token" @@ -781,12 +810,16 @@ def test_service_fabric_tenant_id(): EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: thumbprint, }, ): - token = ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": 
kwargs} + token = getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)(scope, **kwargs) assert token.token == access_token assert token.expires_on == expires_on -def test_token_exchange(tmpdir): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_exchange(tmpdir, get_token_method): exchange_token = "exchange-token" token_file = tmpdir.join("token") token_file.write(exchange_token) @@ -833,7 +866,7 @@ def test_token_exchange(tmpdir): # credential should default to AZURE_CLIENT_ID with mock.patch.dict("os.environ", mock_environ, clear=True): credential = ManagedIdentityCredential(transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token # client_id kwarg should override AZURE_CLIENT_ID @@ -857,7 +890,7 @@ def test_token_exchange(tmpdir): with mock.patch.dict("os.environ", mock_environ, clear=True): credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token # AZURE_CLIENT_ID may not have a value, in which case client_id is required @@ -891,11 +924,12 @@ def test_token_exchange(tmpdir): ManagedIdentityCredential() credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token -def test_token_exchange_tenant_id(tmpdir): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_token_exchange_tenant_id(tmpdir, get_token_method): exchange_token = "exchange-token" token_file = tmpdir.join("token") token_file.write(exchange_token) @@ -941,7 +975,10 @@ def test_token_exchange_tenant_id(tmpdir): } with mock.patch.dict("os.environ", mock_environ, clear=True): credential = 
ManagedIdentityCredential(transport=transport) - token = credential.get_token(scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)(scope, **kwargs) assert token.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py index 239f074c93f..ee988f67dbf 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py @@ -2,11 +2,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from itertools import product import os import time from unittest import mock -from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from azure.identity import CredentialUnavailableError from azure.identity.aio import ManagedIdentityCredential @@ -17,7 +17,7 @@ from azure.identity._internal import within_credential_chain import pytest -from helpers import build_aad_response, mock_response, Request +from helpers import build_aad_response, mock_response, Request, GET_TOKEN_METHODS from helpers_async import async_validating_transport, AsyncMockTransport from test_managed_identity import ALL_ENVIRONMENTS @@ -26,8 +26,8 @@ MANAGED_IDENTITY_ENVIRON = "azure.identity.aio._credentials.managed_identity.os. 
@pytest.mark.asyncio -@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) -async def test_custom_hooks(environ): +@pytest.mark.parametrize("environ,get_token_method", product(ALL_ENVIRONMENTS, GET_TOKEN_METHODS)) +async def test_custom_hooks(environ, get_token_method): """The credential's pipeline should include azure-core's CustomHookPolicy""" scope = "scope" @@ -52,7 +52,7 @@ async def test_custom_hooks(environ): credential = ManagedIdentityCredential( transport=transport, raw_request_hook=request_hook, raw_response_hook=response_hook ) - await credential.get_token(scope) + await getattr(credential, get_token_method)(scope) assert request_hook.call_count == 1 assert response_hook.call_count == 1 @@ -62,8 +62,8 @@ async def test_custom_hooks(environ): @pytest.mark.asyncio -@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) -async def test_tenant_id(environ): +@pytest.mark.parametrize("environ,get_token_method", product(ALL_ENVIRONMENTS, GET_TOKEN_METHODS)) +async def test_tenant_id(environ, get_token_method): scope = "scope" expected_token = "***" request_hook = mock.Mock() @@ -86,7 +86,7 @@ async def test_tenant_id(environ): credential = ManagedIdentityCredential( transport=transport, raw_request_hook=request_hook, raw_response_hook=response_hook ) - await credential.get_token(scope) + await getattr(credential, get_token_method)(scope) assert request_hook.call_count == 1 assert response_hook.call_count == 1 @@ -134,12 +134,13 @@ async def test_context_manager_incomplete_configuration(): @pytest.mark.asyncio -async def test_cloud_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cloud_shell(get_token_method): """Cloud Shell environment: only MSI_ENDPOINT set""" access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token endpoint = "http://localhost:42/token" scope = "scope" transport = async_validating_transport( @@ -166,17 +167,18 @@ async def 
test_cloud_shell(): ) with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: endpoint}): - token = await ManagedIdentityCredential(transport=transport).get_token(scope) - assert token == expected_token + token = await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) + assert token.token == expected_token @pytest.mark.asyncio -async def test_cloud_shell_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cloud_shell_tenant_id(get_token_method): """Cloud Shell environment: only MSI_ENDPOINT set""" access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token endpoint = "http://localhost:42/token" scope = "scope" transport = async_validating_transport( @@ -203,15 +205,20 @@ async def test_cloud_shell_tenant_id(): ) with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: endpoint}): - token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") - assert token == expected_token + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope, **kwargs) + assert token.token == expected_token @pytest.mark.asyncio -async def test_azure_ml(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_ml(get_token_method): """Azure ML: MSI_ENDPOINT, MSI_SECRET set (like App Service 2017-09-01 but with a different response format)""" - expected_token = AccessToken("****", int(time.time()) + 3600) + expected_token = "****" + expires_on = int(time.time()) + 3600 url = "http://localhost:42/token" secret = "expected-secret" scope = "scope" @@ -235,9 +242,9 @@ async def test_azure_ml(): responses=[ mock_response( json_payload={ - "access_token": expected_token.token, + "access_token": expected_token, "expires_in": 
3600, - "expires_on": expected_token.expires_on, + "expires_on": expires_on, "resource": scope, "token_type": "Bearer", } @@ -252,21 +259,23 @@ async def test_azure_ml(): clear=True, ): credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope) - assert token.token == expected_token.token - assert token.expires_on == expected_token.expires_on + token = await getattr(credential, get_token_method)(scope) + assert token.token == expected_token + assert token.expires_on == expires_on credential = ManagedIdentityCredential(transport=transport, client_id=client_id) - token = await credential.get_token(scope) - assert token.token == expected_token.token - assert token.expires_on == expected_token.expires_on + token = await getattr(credential, get_token_method)(scope) + assert token.token == expected_token + assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_azure_ml_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_ml_tenant_id(get_token_method): """Azure ML: MSI_ENDPOINT, MSI_SECRET set (like App Service 2017-09-01 but with a different response format)""" - expected_token = AccessToken("****", int(time.time()) + 3600) + expected_token = "****" + expires_on = int(time.time()) + 3600 url = "http://localhost:42/token" secret = "expected-secret" scope = "scope" @@ -290,9 +299,9 @@ async def test_azure_ml_tenant_id(): responses=[ mock_response( json_payload={ - "access_token": expected_token.token, + "access_token": expected_token, "expires_in": 3600, - "expires_on": expected_token.expires_on, + "expires_on": expires_on, "resource": scope, "token_type": "Bearer", } @@ -307,13 +316,17 @@ async def test_azure_ml_tenant_id(): clear=True, ): credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope, tenant_id="tenant_id") - assert token.token == expected_token.token - assert token.expires_on == 
expected_token.expires_on + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)(scope, **kwargs) + assert token.token == expected_token + assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_cloud_shell_identity_config(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cloud_shell_identity_config(get_token_method): """Cloud Shell environment: only MSI_ENDPOINT set""" expected_token = "****" @@ -354,23 +367,24 @@ async def test_cloud_shell_identity_config(): with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: endpoint}, clear=True): credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on credential = ManagedIdentityCredential(transport=transport, identity_config={param_name: param_value}) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_app_service_2017_09_01(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_app_service_2017_09_01(get_token_method): """When the environment for 2019-08-01 is not configured, 2017-09-01 should be used.""" access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token url = "http://localhost:42/token" secret = "expected-secret" scope = "scope" @@ -414,18 +428,19 @@ async def test_app_service_2017_09_01(): clear=True, ): credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope) - assert token == expected_token + token = await getattr(credential, 
get_token_method)(scope) + assert token.token == expected_token assert token.expires_on == expires_on credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope) - assert token == expected_token + token = await getattr(credential, get_token_method)(scope) + assert token.token == expected_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_app_service_2019_08_01(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_app_service_2019_08_01(get_token_method): """App Service 2019-08-01: IDENTITY_ENDPOINT, IDENTITY_HEADER set""" access_token = "****" @@ -467,13 +482,14 @@ async def test_app_service_2019_08_01(): }, clear=True, ): - token = await ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope) + token = await getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)(scope) assert token.token == access_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_app_service_2019_08_01_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_app_service_2019_08_01_tenant_id(get_token_method): access_token = "****" expires_on = 42 endpoint = "http://localhost:42/token" @@ -513,13 +529,19 @@ async def test_app_service_2019_08_01_tenant_id(): }, clear=True, ): - token = await ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)( + scope, **kwargs + ) assert token.token == access_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_app_service_user_assigned_identity(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def 
test_app_service_user_assigned_identity(get_token_method): """App Service 2019-08-01: MSI_ENDPOINT, MSI_SECRET set""" expected_token = "****" @@ -564,20 +586,21 @@ async def test_app_service_user_assigned_identity(): clear=True, ): credential = ManagedIdentityCredential(client_id=client_id, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on credential = ManagedIdentityCredential( client_id=client_id, transport=transport, identity_config={param_name: param_value} ) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_client_id_none(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_client_id_none(get_token_method): """the credential should ignore client_id=None""" expected_access_token = "****" @@ -596,22 +619,23 @@ async def test_client_id_none(): with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, {}, clear=True): credential = ManagedIdentityCredential(client_id=None, transport=mock.Mock(send=send)) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_access_token with mock.patch.dict( MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}, clear=True ): credential = ManagedIdentityCredential(client_id=None, transport=mock.Mock(send=send)) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_access_token @pytest.mark.asyncio -async def test_imds(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_imds(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, 
expires_on) + expected_token = access_token scope = "scope" transport = async_validating_transport( requests=[ @@ -639,15 +663,16 @@ async def test_imds(): # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential with mock.patch.dict("os.environ", clear=True): - token = await ManagedIdentityCredential(transport=transport).get_token(scope) - assert token == expected_token + token = await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) + assert token.token == expected_token @pytest.mark.asyncio -async def test_imds_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_imds_tenant_id(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" transport = async_validating_transport( requests=[ @@ -675,15 +700,19 @@ async def test_imds_tenant_id(): # ensure e.g. $MSI_ENDPOINT isn't set, so we get ImdsCredential with mock.patch.dict("os.environ", clear=True): - token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") - assert token == expected_token + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope, **kwargs) + assert token.token == expected_token @pytest.mark.asyncio -async def test_imds_user_assigned_identity(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_imds_user_assigned_identity(get_token_method): access_token = "****" expires_on = 42 - expected_token = AccessToken(access_token, expires_on) + expected_token = access_token scope = "scope" client_id = "some-guid" transport = async_validating_transport( @@ -713,12 +742,15 @@ async def test_imds_user_assigned_identity(): # ensure e.g. 
$MSI_ENDPOINT isn't set, so we get ImdsCredential with mock.patch.dict("os.environ", clear=True): - token = await ManagedIdentityCredential(client_id=client_id, transport=transport).get_token(scope) - assert token == expected_token + token = await getattr(ManagedIdentityCredential(client_id=client_id, transport=transport), get_token_method)( + scope + ) + assert token.token == expected_token @pytest.mark.asyncio -async def test_imds_text_response(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_imds_text_response(get_token_method): async def send(request, **kwargs): response = mock.Mock( text=lambda encoding=None: b"{This is a text response}", @@ -731,12 +763,13 @@ async def test_imds_text_response(): within_credential_chain.set(True) credential = ManagedIdentityCredential(transport=mock.Mock(send=send)) with pytest.raises(CredentialUnavailableError): - token = await credential.get_token("") + token = await getattr(credential, get_token_method)("") within_credential_chain.set(False) @pytest.mark.asyncio -async def test_service_fabric(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_service_fabric(get_token_method): """Service Fabric 2019-07-01-preview""" access_token = "****" expires_on = 42 @@ -772,13 +805,14 @@ async def test_service_fabric(): EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: thumbprint, }, ): - token = await ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope) + token = await getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)(scope) assert token.token == access_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_service_fabric_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_service_fabric_tenant_id(get_token_method): access_token = "****" expires_on = 42 endpoint = "http://localhost:42/token" @@ -813,13 +847,19 @@ async def 
test_service_fabric_tenant_id(): EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: thumbprint, }, ): - token = await ManagedIdentityCredential(transport=mock.Mock(send=send)).get_token(scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(ManagedIdentityCredential(transport=mock.Mock(send=send)), get_token_method)( + scope, **kwargs + ) assert token.token == access_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_azure_arc(tmpdir): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_arc(tmpdir, get_token_method): """Azure Arc 2020-06-01""" access_token = "****" api_version = "2020-06-01" @@ -868,13 +908,14 @@ async def test_azure_arc(tmpdir): {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, ): with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None): - token = await ManagedIdentityCredential(transport=transport).get_token(scope) + token = await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) assert token.token == access_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_azure_arc_tenant_id(tmpdir): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_arc_tenant_id(tmpdir, get_token_method): access_token = "****" api_version = "2020-06-01" expires_on = 42 @@ -922,13 +963,17 @@ async def test_azure_arc_tenant_id(tmpdir): {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint}, ): with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None): - token = await ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == 
"get_token_info": + kwargs = {"options": kwargs} + token = await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope, **kwargs) assert token.token == access_token assert token.expires_on == expires_on @pytest.mark.asyncio -async def test_azure_arc_client_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_arc_client_id(get_token_method): """Azure Arc doesn't support user-assigned managed identity""" with mock.patch( "os.environ", @@ -940,11 +985,12 @@ async def test_azure_arc_client_id(): credential = ManagedIdentityCredential(client_id="some-guid") with pytest.raises(ClientAuthenticationError): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.asyncio -async def test_azure_arc_key_too_large(tmp_path): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_arc_key_too_large(tmp_path, get_token_method): api_version = "2020-06-01" identity_endpoint = "http://localhost:42/token" imds_endpoint = "http://localhost:42" @@ -974,12 +1020,13 @@ async def test_azure_arc_key_too_large(tmp_path): ): with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)): with pytest.raises(ClientAuthenticationError) as ex: - await ManagedIdentityCredential(transport=transport).get_token(scope) + await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) assert "file size" in str(ex.value) @pytest.mark.asyncio -async def test_azure_arc_key_not_exist(tmp_path): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_arc_key_not_exist(get_token_method): api_version = "2020-06-01" identity_endpoint = "http://localhost:42/token" imds_endpoint = "http://localhost:42" @@ -1003,12 +1050,13 @@ async def test_azure_arc_key_not_exist(tmp_path): {EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: 
imds_endpoint}, ): with pytest.raises(ClientAuthenticationError) as ex: - await ManagedIdentityCredential(transport=transport).get_token(scope) + await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) assert "not exist" in str(ex.value) @pytest.mark.asyncio -async def test_azure_arc_key_invalid(tmp_path): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_azure_arc_key_invalid(tmp_path, get_token_method): api_version = "2020-06-01" identity_endpoint = "http://localhost:42/token" imds_endpoint = "http://localhost:42" @@ -1043,17 +1091,18 @@ async def test_azure_arc_key_invalid(tmp_path): ): with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"): with pytest.raises(ClientAuthenticationError) as ex: - await ManagedIdentityCredential(transport=transport).get_token(scope) + await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) assert "Unexpected file path" in str(ex.value) with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)): with pytest.raises(ClientAuthenticationError) as ex: - await ManagedIdentityCredential(transport=transport).get_token(scope) + await getattr(ManagedIdentityCredential(transport=transport), get_token_method)(scope) assert "extension" in str(ex.value) @pytest.mark.asyncio -async def test_token_exchange(tmpdir): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_exchange(tmpdir, get_token_method): exchange_token = "exchange-token" token_file = tmpdir.join("token") token_file.write(exchange_token) @@ -1100,7 +1149,7 @@ async def test_token_exchange(tmpdir): # credential should default to AZURE_CLIENT_ID with mock.patch.dict("os.environ", mock_environ, clear=True): credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert 
token.token == access_token # client_id kwarg should override AZURE_CLIENT_ID @@ -1124,7 +1173,7 @@ async def test_token_exchange(tmpdir): with mock.patch.dict("os.environ", mock_environ, clear=True): credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token # AZURE_CLIENT_ID may not have a value, in which case client_id is required @@ -1158,12 +1207,13 @@ async def test_token_exchange(tmpdir): ManagedIdentityCredential() credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token @pytest.mark.asyncio -async def test_token_exchange_tenant_id(tmpdir): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_token_exchange_tenant_id(tmpdir, get_token_method): exchange_token = "exchange-token" token_file = tmpdir.join("token") token_file.write(exchange_token) @@ -1210,7 +1260,10 @@ async def test_token_exchange_tenant_id(tmpdir): # credential should default to AZURE_CLIENT_ID with mock.patch.dict("os.environ", mock_environ, clear=True): credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope, tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)(scope, **kwargs) assert token.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_multi_tenant_auth.py b/sdk/identity/azure-identity/tests/test_multi_tenant_auth.py index 2c5d8c385bf..f86df54a172 100644 --- a/sdk/identity/azure-identity/tests/test_multi_tenant_auth.py +++ b/sdk/identity/azure-identity/tests/test_multi_tenant_auth.py @@ -11,11 +11,13 @@ from 
azure.core import PipelineClient from azure.core.rest import HttpRequest, HttpResponse from azure.identity import ClientSecretCredential +from helpers import GET_TOKEN_METHODS + class TestMultiTenantAuth(AzureRecordedTestCase): - def _send_request(self, credential: ClientSecretCredential) -> HttpResponse: + def _send_request(self, credential: ClientSecretCredential, get_token_method: str) -> HttpResponse: client = PipelineClient(base_url="https://graph.microsoft.com") - token = credential.get_token("https://graph.microsoft.com/.default") + token = getattr(credential, get_token_method)("https://graph.microsoft.com/.default") headers = {"Authorization": "Bearer " + token.token, "ConsistencyLevel": "eventual"} request = HttpRequest("GET", "https://graph.microsoft.com/v1.0/applications/$count", headers=headers) response = client.send_request(request) @@ -26,11 +28,12 @@ class TestMultiTenantAuth(AzureRecordedTestCase): is_live() and not os.environ.get("AZURE_IDENTITY_MULTI_TENANT_CLIENT_ID"), reason="Multi-tenant envvars not configured.", ) - def test_multi_tenant_client_secret_graph_call(self, recorded_test, environment_variables): + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + def test_multi_tenant_client_secret_graph_call(self, recorded_test, environment_variables, get_token_method): client_id = environment_variables.get("AZURE_IDENTITY_MULTI_TENANT_CLIENT_ID") tenant_id = environment_variables.get("AZURE_IDENTITY_MULTI_TENANT_TENANT_ID") client_secret = environment_variables.get("AZURE_IDENTITY_MULTI_TENANT_CLIENT_SECRET") credential = ClientSecretCredential(tenant_id, client_id, client_secret) - response = self._send_request(credential) + response = self._send_request(credential, get_token_method) assert response.status_code == 200 assert int(response.text()) > 0 diff --git a/sdk/identity/azure-identity/tests/test_multi_tenant_auth_async.py b/sdk/identity/azure-identity/tests/test_multi_tenant_auth_async.py index c6ca0547776..6c1e50a7a0d 
100644 --- a/sdk/identity/azure-identity/tests/test_multi_tenant_auth_async.py +++ b/sdk/identity/azure-identity/tests/test_multi_tenant_auth_async.py @@ -11,11 +11,13 @@ from azure.core import AsyncPipelineClient from azure.core.rest import HttpRequest, HttpResponse from azure.identity.aio import ClientSecretCredential +from helpers import GET_TOKEN_METHODS + class TestMultiTenantAuthAsync(AzureRecordedTestCase): - async def _send_request(self, credential: ClientSecretCredential) -> HttpResponse: + async def _send_request(self, credential: ClientSecretCredential, get_token_method: str) -> HttpResponse: client = AsyncPipelineClient(base_url="https://graph.microsoft.com") - token = await credential.get_token("https://graph.microsoft.com/.default") + token = await getattr(credential, get_token_method)("https://graph.microsoft.com/.default") headers = {"Authorization": "Bearer " + token.token, "ConsistencyLevel": "eventual"} request = HttpRequest("GET", "https://graph.microsoft.com/v1.0/applications/$count", headers=headers) response = await client.send_request(request, stream=False) @@ -27,12 +29,13 @@ class TestMultiTenantAuthAsync(AzureRecordedTestCase): is_live() and not os.environ.get("AZURE_IDENTITY_MULTI_TENANT_CLIENT_ID"), reason="Multi-tenant envvars not configured.", ) - async def test_multi_tenant_client_secret_graph_call(self, recorded_test, environment_variables): + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) + async def test_multi_tenant_client_secret_graph_call(self, recorded_test, environment_variables, get_token_method): client_id = environment_variables.get("AZURE_IDENTITY_MULTI_TENANT_CLIENT_ID") tenant_id = environment_variables.get("AZURE_IDENTITY_MULTI_TENANT_TENANT_ID") client_secret = environment_variables.get("AZURE_IDENTITY_MULTI_TENANT_CLIENT_SECRET") credential = ClientSecretCredential(tenant_id, client_id, client_secret) async with credential: - response = await self._send_request(credential) + response = await 
self._send_request(credential, get_token_method) assert response.status_code == 200 assert int(response.text()) > 0 diff --git a/sdk/identity/azure-identity/tests/test_obo.py b/sdk/identity/azure-identity/tests/test_obo.py index 3678b4fa850..bbcd50e6c60 100644 --- a/sdk/identity/azure-identity/tests/test_obo.py +++ b/sdk/identity/azure-identity/tests/test_obo.py @@ -3,11 +3,8 @@ # Licensed under the MIT License. # ------------------------------------ import os - -try: - from unittest.mock import Mock, patch -except ImportError: - from mock import Mock, patch # type: ignore +from itertools import product +from unittest.mock import Mock, patch from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import OnBehalfOfCredential, UsernamePasswordCredential @@ -17,7 +14,7 @@ from azure.identity._internal.user_agent import USER_AGENT import pytest from urllib.parse import urlparse -from helpers import build_aad_response, FAKE_CLIENT_ID, get_discovery_response, mock_response +from helpers import build_aad_response, FAKE_CLIENT_ID, get_discovery_response, mock_response, GET_TOKEN_METHODS from recorded_test_case import RecordedTestCase from test_certificate_credential import PEM_CERT_PATH from devtools_testutils import is_live, recorded_by_proxy @@ -95,7 +92,8 @@ class TestObo(RecordedTestCase): credential.get_token(self.obo_settings["scope"]) -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -124,22 +122,28 @@ def test_multitenant_authentication(): transport=transport, additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + 
if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) -def test_authority(authority): +@pytest.mark.parametrize("authority,get_token_method", product(("localhost", "https://localhost"), GET_TOKEN_METHODS)) +def test_authority(authority, get_token_method): """the credential should accept an authority, with or without scheme, as an argument or environment variable""" tenant_id = "expected-tenant" @@ -156,7 +160,7 @@ def test_authority(authority): ) with patch("msal.ConfidentialClientApplication", mock_ctor): # must call get_token because the credential constructs the MSAL application lazily - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_ctor.call_count == 1 _, kwargs = mock_ctor.call_args @@ -167,7 +171,7 @@ def test_authority(authority): with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): credential = OnBehalfOfCredential(tenant_id, "client-id", client_secret="secret", user_assertion="assertion") with patch("msal.ConfidentialClientApplication", mock_ctor): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_ctor.call_count == 1 _, kwargs = mock_ctor.call_args @@ -185,16 +189,18 @@ def test_tenant_id_validation(): OnBehalfOfCredential(tenant, "client-id", client_secret="secret", 
user_assertion="assertion") -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = OnBehalfOfCredential( "tenant-id", "client-id", client_secret="client-secret", user_assertion="assertion" ) with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock(), on_exception=lambda _: False) def send(request, **kwargs): @@ -215,7 +221,7 @@ def test_policies_configurable(): policies=[ContentDecodePolicy(), policy], transport=Mock(send=send), ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert policy.on_request.called @@ -231,7 +237,8 @@ def test_no_client_credential(): credential = OnBehalfOfCredential("tenant-id", "client-id", user_assertion="assertion") -def test_client_assertion_func(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_client_assertion_func(get_token_method): """The credential should accept a client_assertion_func""" expected_client_assertion = "client-assertion" expected_user_assertion = "user-assertion" @@ -263,7 +270,7 @@ def test_client_assertion_func(): transport=transport, ) - access_token = credential.get_token("scope") + access_token = getattr(credential, get_token_method)("scope") assert access_token.token == expected_token assert func_call_count == 1 diff --git a/sdk/identity/azure-identity/tests/test_obo_async.py b/sdk/identity/azure-identity/tests/test_obo_async.py index f9d80323e84..5e90040a0af 100644 --- a/sdk/identity/azure-identity/tests/test_obo_async.py +++ b/sdk/identity/azure-identity/tests/test_obo_async.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. 
# ------------------------------------ import os +from itertools import product from urllib.parse import urlparse from unittest.mock import Mock, patch from test_certificate_credential import PEM_CERT_PATH @@ -17,7 +18,7 @@ from azure.identity._internal.user_agent import USER_AGENT from azure.identity.aio import OnBehalfOfCredential import pytest -from helpers import build_aad_response, get_discovery_response, mock_response, FAKE_CLIENT_ID +from helpers import build_aad_response, get_discovery_response, mock_response, FAKE_CLIENT_ID, GET_TOKEN_METHODS from helpers_async import AsyncMockTransport from recorded_test_case import RecordedTestCase @@ -123,7 +124,8 @@ async def test_context_manager(): @pytest.mark.asyncio -async def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(get_token_method): first_tenant = "first-tenant" first_token = "***" second_tenant = "second-tenant" @@ -149,27 +151,33 @@ async def test_multitenant_authentication(): transport=transport, additionally_allowed_tenants=["*"], ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token assert transport.send.call_count == 1 - token = await credential.get_token("scope", tenant_id=first_tenant) + kwargs = {"tenant_id": first_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token assert transport.send.call_count == 1 # should be a cached token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token assert transport.send.call_count == 2 # should still default to 
the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token assert transport.send.call_count == 2 # should be a cached token @pytest.mark.asyncio -@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) -async def test_authority(authority): +@pytest.mark.parametrize("authority,get_token_method", product(("localhost", "https://localhost"), GET_TOKEN_METHODS)) +async def test_authority(authority, get_token_method): """the credential should accept an authority, with or without scheme, as an argument or environment variable""" tenant_id = "expected-tenant" @@ -194,7 +202,7 @@ async def test_authority(authority): authority=authority, transport=transport, ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token # authority can be configured via environment variable @@ -202,12 +210,13 @@ async def test_authority(authority): credential = OnBehalfOfCredential( tenant_id, "client-id", client_secret="secret", user_assertion="assertion", transport=transport ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token @pytest.mark.asyncio -async def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock(), on_exception=lambda _: False) async def send(request, **kwargs): @@ -228,7 +237,7 @@ async def test_policies_configurable(): policies=[ContentDecodePolicy(), policy], transport=Mock(send=send), ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert policy.on_request.called @@ -239,7 +248,8 @@ def test_invalid_cert(): @pytest.mark.asyncio -async def test_refresh_token(): 
+@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_refresh_token(get_token_method): first_token = "***" second_token = first_token * 2 refresh_token = "refresh-token" @@ -264,10 +274,10 @@ async def test_refresh_token(): credential = OnBehalfOfCredential( "tenant-id", "client-id", client_secret="secret", user_assertion="assertion", transport=Mock(send=send) ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == second_token assert requests == 2 @@ -285,13 +295,14 @@ def test_tenant_id_validation(): @pytest.mark.asyncio -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" credential = OnBehalfOfCredential( "tenant-id", "client-id", client_secret="client-secret", user_assertion="assertion" ) with pytest.raises(ValueError): - await credential.get_token() + await getattr(credential, get_token_method)() @pytest.mark.asyncio @@ -309,7 +320,8 @@ async def test_no_client_credential(): @pytest.mark.asyncio -async def test_client_assertion_func(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_client_assertion_func(get_token_method): """The credential should accept a client_assertion_func""" expected_client_assertion = "client-assertion" expected_user_assertion = "user-assertion" @@ -340,7 +352,7 @@ async def test_client_assertion_func(): user_assertion=expected_user_assertion, transport=transport, ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token assert func_call_count == 1 diff --git 
a/sdk/identity/azure-identity/tests/test_powershell_credential.py b/sdk/identity/azure-identity/tests/test_powershell_credential.py index 9441f1c6975..23b663a667e 100644 --- a/sdk/identity/azure-identity/tests/test_powershell_credential.py +++ b/sdk/identity/azure-identity/tests/test_powershell_credential.py @@ -3,17 +3,14 @@ # Licensed under the MIT License. # ------------------------------------ import base64 +from itertools import product import logging from platform import python_version import re import subprocess import sys import time - -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore +from unittest.mock import Mock, patch from azure.core.exceptions import ClientAuthenticationError from azure.identity import AzurePowerShellCredential, CredentialUnavailableError @@ -28,7 +25,7 @@ from azure.identity._credentials.azure_powershell import ( import pytest from credscan_ignore import POWERSHELL_INVALID_OPERATION_EXCEPTION, POWERSHELL_NOT_LOGGED_IN_ERROR -from helpers import INVALID_CHARACTERS +from helpers import INVALID_CHARACTERS, GET_TOKEN_METHODS POPEN = AzurePowerShellCredential.__module__ + ".subprocess.Popen" @@ -48,29 +45,33 @@ def get_mock_Popen(return_code=0, stdout="", stderr=""): return Mock(return_value=Mock(communicate=communicate, returncode=return_code)) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" with pytest.raises(ValueError): - AzurePowerShellCredential().get_token() + getattr(AzurePowerShellCredential(), get_token_method)() -def test_multiple_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multiple_scopes(get_token_method): """The credential should raise ValueError when get_token is called with more than one scope""" with pytest.raises(ValueError): - 
AzurePowerShellCredential().get_token("one scope", "and another") + getattr(AzurePowerShellCredential(), get_token_method)("one scope", "and another") -def test_cannot_execute_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_cannot_execute_shell(get_token_method): """The credential should raise CredentialUnavailableError when the subprocess doesn't start""" with patch(POPEN, Mock(side_effect=OSError)): with pytest.raises(CredentialUnavailableError): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), get_token_method)("scope") -def test_invalid_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_invalid_tenant_id(get_token_method): """Invalid tenant IDs should raise ValueErrors.""" for c in INVALID_CHARACTERS: @@ -78,19 +79,23 @@ def test_invalid_tenant_id(): AzurePowerShellCredential(tenant_id="tenant" + c) with pytest.raises(ValueError): - AzurePowerShellCredential().get_token("scope", tenant_id="tenant" + c) + kwargs = {"tenant_id": "tenant" + c} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(AzurePowerShellCredential(), get_token_method)("scope", **kwargs) -def test_invalid_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_invalid_scopes(get_token_method): """Scopes with invalid characters should raise ValueErrors.""" for c in INVALID_CHARACTERS: with pytest.raises(ValueError): - AzurePowerShellCredential().get_token("scope" + c) + getattr(AzurePowerShellCredential(), get_token_method)("scope" + c) -@pytest.mark.parametrize("stderr", ("", PREPARING_MODULES)) -def test_get_token(stderr): +@pytest.mark.parametrize("stderr,get_token_method", product(("", PREPARING_MODULES), GET_TOKEN_METHODS)) +def test_get_token(stderr, get_token_method): """The credential should parse Azure PowerShell's output to an AccessToken""" expected_access_token = "access" @@ -100,7 +105,7 @@ def 
test_get_token(stderr): Popen = get_mock_Popen(stdout=stdout, stderr=stderr) with patch(POPEN, Popen): - token = AzurePowerShellCredential().get_token(scope) + token = getattr(AzurePowerShellCredential(), get_token_method)(scope) assert token.token == expected_access_token assert token.expires_on == expected_expires_on @@ -123,8 +128,8 @@ def test_get_token(stderr): assert "timeout" in kwargs -@pytest.mark.parametrize("stderr", ("", PREPARING_MODULES)) -def test_get_token_tenant_id(stderr): +@pytest.mark.parametrize("stderr,get_token_method", product(("", PREPARING_MODULES), GET_TOKEN_METHODS)) +def test_get_token_tenant_id(stderr, get_token_method): expected_access_token = "access" expected_expires_on = 1617923581 scope = "scope" @@ -132,69 +137,82 @@ def test_get_token_tenant_id(stderr): Popen = get_mock_Popen(stdout=stdout, stderr=stderr) with patch(POPEN, Popen): - token = AzurePowerShellCredential().get_token(scope, tenant_id="tenant-id") + kwargs = {"tenant_id": "tenant-id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(AzurePowerShellCredential(), get_token_method)(scope, **kwargs) assert token.token == expected_access_token assert token.expires_on == expected_expires_on -def test_ignores_extraneous_stdout_content(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_ignores_extraneous_stdout_content(get_token_method): expected_access_token = "access" expected_expires_on = 1617923581 motd = "MOTD: Customize your experience: save your profile to $HOME/.config/PowerShell\n" Popen = get_mock_Popen(stdout=motd + "azsdk%{}%{}".format(expected_access_token, expected_expires_on)) with patch(POPEN, Popen): - token = AzurePowerShellCredential().get_token("scope") + token = getattr(AzurePowerShellCredential(), get_token_method)("scope") assert token.token == expected_access_token assert token.expires_on == expected_expires_on -def test_az_powershell_not_installed(): 
+@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_az_powershell_not_installed(get_token_method): """The credential should raise CredentialUnavailableError when Azure PowerShell isn't installed""" with patch(POPEN, get_mock_Popen(stdout=NO_AZ_ACCOUNT_MODULE)): with pytest.raises(CredentialUnavailableError, match=AZ_ACCOUNT_NOT_INSTALLED): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), get_token_method)("scope") @pytest.mark.parametrize( - "stderr", - ( - "'pwsh' is not recognized as an internal or external command,\r\noperable program or batch file.", - "'powershell' is not recognized as an internal or external command,\r\noperable program or batch file.", + "stderr,get_token_method", + product( + ( + "'pwsh' is not recognized as an internal or external command,\r\noperable program or batch file.", + "'powershell' is not recognized as an internal or external command,\r\noperable program or batch file.", + ), + GET_TOKEN_METHODS, ), ) -def test_powershell_not_installed_cmd(stderr): +def test_powershell_not_installed_cmd(stderr, get_token_method): """The credential should raise CredentialUnavailableError when PowerShell isn't installed""" Popen = get_mock_Popen(return_code=1, stderr=stderr) with patch(POPEN, Popen): with pytest.raises(CredentialUnavailableError, match=POWERSHELL_NOT_INSTALLED): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), get_token_method)("scope") -def test_powershell_not_installed_sh(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_powershell_not_installed_sh(get_token_method): """The credential should raise CredentialUnavailableError when PowerShell isn't installed""" Popen = get_mock_Popen(return_code=127, stderr="/bin/sh: 0: Can't open pwsh") with patch(POPEN, Popen): with pytest.raises(CredentialUnavailableError, match=POWERSHELL_NOT_INSTALLED): - AzurePowerShellCredential().get_token("scope") + 
getattr(AzurePowerShellCredential(), get_token_method)("scope") -@pytest.mark.parametrize("stderr", (POWERSHELL_INVALID_OPERATION_EXCEPTION, POWERSHELL_NOT_LOGGED_IN_ERROR)) -def test_not_logged_in(stderr): +@pytest.mark.parametrize( + "stderr,get_token_method", + product((POWERSHELL_INVALID_OPERATION_EXCEPTION, POWERSHELL_NOT_LOGGED_IN_ERROR), GET_TOKEN_METHODS), +) +def test_not_logged_in(stderr, get_token_method): """The credential should raise CredentialUnavailableError when a user isn't logged in to Azure PowerShell""" Popen = get_mock_Popen(return_code=1, stderr=stderr) with patch(POPEN, Popen): with pytest.raises(CredentialUnavailableError, match=RUN_CONNECT_AZ_ACCOUNT): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), get_token_method)("scope") -def test_blocked_by_execution_policy(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_blocked_by_execution_policy(get_token_method): """The credential should raise CredentialUnavailableError when execution policy blocks Get-AzAccessToken""" stderr = r"""#< CLIXML @@ -202,11 +220,11 @@ def test_blocked_by_execution_policy(): Popen = get_mock_Popen(return_code=1, stderr=stderr) with patch(POPEN, Popen): with pytest.raises(CredentialUnavailableError, match=BLOCKED_BY_EXECUTION_POLICY): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), get_token_method)("scope") -@pytest.mark.skipif(sys.version_info < (3, 3), reason="Python 3.3 added timeout support to Popen") -def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_timeout(get_token_method): """The credential should kill the subprocess after a timeout""" from subprocess import TimeoutExpired @@ -214,7 +232,7 @@ def test_timeout(): proc = Mock(communicate=Mock(side_effect=TimeoutExpired("", 42)), returncode=None) with patch(POPEN, Mock(return_value=proc)): with pytest.raises(CredentialUnavailableError): - 
AzurePowerShellCredential(process_timeout=42).get_token("scope") + getattr(AzurePowerShellCredential(process_timeout=42), get_token_method)("scope") assert proc.communicate.call_count == 1 # Ensure custom timeout is passed to subprocess @@ -223,7 +241,8 @@ def test_timeout(): assert kwargs["timeout"] == 42 -def test_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_unexpected_error(get_token_method): """The credential should log stderr when Get-AzAccessToken returns an unexpected error""" class MockHandler(logging.Handler): @@ -243,7 +262,7 @@ def test_unexpected_error(): Popen = get_mock_Popen(return_code=42, stderr=expected_output) with patch(POPEN, Popen): with pytest.raises(ClientAuthenticationError): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), get_token_method)("scope") for message in mock_handler.messages: if message.levelname == "DEBUG" and expected_output in message.message: @@ -253,13 +272,16 @@ def test_unexpected_error(): @pytest.mark.parametrize( - "error_message", - ( - "'pwsh' is not recognized as an internal or external command,\r\noperable program or batch file.", - "some other message", + "error_message,get_token_method", + product( + ( + "'pwsh' is not recognized as an internal or external command,\r\noperable program or batch file.", + "some other message", + ), + GET_TOKEN_METHODS, ), ) -def test_windows_powershell_fallback(error_message): +def test_windows_powershell_fallback(error_message, get_token_method): """On Windows, the credential should fall back to powershell.exe when pwsh.exe isn't on the path""" class Fake: @@ -285,12 +307,13 @@ def test_windows_powershell_fallback(error_message): with patch.dict("os.environ", {"SYSTEMROOT": "foo"}): with patch(POPEN, Popen): with pytest.raises(CredentialUnavailableError, match=AZ_ACCOUNT_NOT_INSTALLED): - AzurePowerShellCredential().get_token("scope") + getattr(AzurePowerShellCredential(), 
get_token_method)("scope") assert Fake.calls == 2 -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): first_token = "***" second_tenant = "12345" second_token = first_token * 2 @@ -314,18 +337,22 @@ def test_multitenant_authentication(): credential = AzurePowerShellCredential() with patch(POPEN, fake_Popen): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_not_allowed(get_token_method): expected_token = "***" def fake_Popen(command, **_): @@ -346,9 +373,12 @@ def test_multitenant_authentication_not_allowed(): credential = AzurePowerShellCredential() with patch(POPEN, fake_Popen): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = credential.get_token("scope", tenant_id="12345") + kwargs = {"tenant_id": "12345"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_powershell_credential_async.py 
b/sdk/identity/azure-identity/tests/test_powershell_credential_async.py index bc694b7ef01..2764575655e 100644 --- a/sdk/identity/azure-identity/tests/test_powershell_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_powershell_credential_async.py @@ -4,6 +4,7 @@ # ------------------------------------ import asyncio import base64 +from itertools import product import logging import re import sys @@ -24,7 +25,7 @@ from azure.identity._credentials.azure_powershell import ( import pytest from credscan_ignore import POWERSHELL_INVALID_OPERATION_EXCEPTION, POWERSHELL_NOT_LOGGED_IN_ERROR -from helpers import INVALID_CHARACTERS +from helpers import INVALID_CHARACTERS, GET_TOKEN_METHODS from helpers_async import get_completed_future from test_powershell_credential import PREPARING_MODULES @@ -38,29 +39,33 @@ def get_mock_exec(return_code=0, stdout="", stderr=""): return Mock(return_value=get_completed_future(Mock(communicate=communicate, returncode=return_code))) -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise ValueError when get_token is called with no scopes""" with pytest.raises(ValueError): - await AzurePowerShellCredential().get_token() + await getattr(AzurePowerShellCredential(), get_token_method)() -async def test_multiple_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multiple_scopes(get_token_method): """The credential should raise ValueError when get_token is called with more than one scope""" with pytest.raises(ValueError): - await AzurePowerShellCredential().get_token("one scope", "and another") + await getattr(AzurePowerShellCredential(), get_token_method)("one scope", "and another") -async def test_cannot_execute_shell(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cannot_execute_shell(get_token_method): """The credential should raise 
CredentialUnavailableError when the subprocess doesn't start""" with patch(CREATE_SUBPROCESS_EXEC, Mock(side_effect=OSError)): with pytest.raises(CredentialUnavailableError): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") -async def test_invalid_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_invalid_tenant_id(get_token_method): """Invalid tenant IDs should raise ValueErrors.""" for c in INVALID_CHARACTERS: @@ -68,19 +73,23 @@ async def test_invalid_tenant_id(): AzurePowerShellCredential(tenant_id="tenant" + c) with pytest.raises(ValueError): - await AzurePowerShellCredential().get_token("scope", tenant_id="tenant" + c) + kwargs = {"tenant_id": "tenant" + c} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(AzurePowerShellCredential(), get_token_method)("scope", **kwargs) -async def test_invalid_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_invalid_scopes(get_token_method): """Scopes with invalid characters should raise ValueErrors.""" for c in INVALID_CHARACTERS: with pytest.raises(ValueError): - await AzurePowerShellCredential().get_token("scope" + c) + await getattr(AzurePowerShellCredential(), get_token_method)("scope" + c) -@pytest.mark.parametrize("stderr", ("", PREPARING_MODULES)) -async def test_get_token(stderr): +@pytest.mark.parametrize("stderr,get_token_method", product(("", PREPARING_MODULES), GET_TOKEN_METHODS)) +async def test_get_token(stderr, get_token_method): """The credential should parse Azure PowerShell's output to an AccessToken""" expected_access_token = "access" @@ -90,7 +99,7 @@ async def test_get_token(stderr): mock_exec = get_mock_exec(stdout=stdout, stderr=stderr) with patch(CREATE_SUBPROCESS_EXEC, mock_exec): - token = await AzurePowerShellCredential().get_token(scope) + token = await getattr(AzurePowerShellCredential(), 
get_token_method)(scope) assert token.token == expected_access_token assert token.expires_on == expected_expires_on @@ -110,8 +119,8 @@ async def test_get_token(stderr): assert mock_exec().result().communicate.call_count == 1 -@pytest.mark.parametrize("stderr", ("", PREPARING_MODULES)) -async def test_get_token_tenant_id(stderr): +@pytest.mark.parametrize("stderr,get_token_method", product(("", PREPARING_MODULES), GET_TOKEN_METHODS)) +async def test_get_token_tenant_id(stderr, get_token_method): expected_access_token = "access" expected_expires_on = 1617923581 scope = "scope" @@ -119,69 +128,82 @@ async def test_get_token_tenant_id(stderr): mock_exec = get_mock_exec(stdout=stdout, stderr=stderr) with patch(CREATE_SUBPROCESS_EXEC, mock_exec): - token = await AzurePowerShellCredential().get_token(scope, tenant_id="tenant-id") + kwargs = {"tenant_id": "tenant-id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(AzurePowerShellCredential(), get_token_method)(scope, **kwargs) assert token.token == expected_access_token assert token.expires_on == expected_expires_on -async def test_ignores_extraneous_stdout_content(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_ignores_extraneous_stdout_content(get_token_method): expected_access_token = "access" expected_expires_on = 1617923581 motd = "MOTD: Customize your experience: save your profile to $HOME/.config/PowerShell\n" mock_exec = get_mock_exec(stdout=motd + "azsdk%{}%{}".format(expected_access_token, expected_expires_on)) with patch(CREATE_SUBPROCESS_EXEC, mock_exec): - token = await AzurePowerShellCredential().get_token("scope") + token = await getattr(AzurePowerShellCredential(), get_token_method)("scope") assert token.token == expected_access_token assert token.expires_on == expected_expires_on -async def test_az_powershell_not_installed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def 
test_az_powershell_not_installed(get_token_method): """The credential should raise CredentialUnavailableError when Azure PowerShell isn't installed""" with patch(CREATE_SUBPROCESS_EXEC, get_mock_exec(stdout=NO_AZ_ACCOUNT_MODULE)): with pytest.raises(CredentialUnavailableError, match=AZ_ACCOUNT_NOT_INSTALLED): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") @pytest.mark.parametrize( - "stderr", - ( - "'pwsh' is not recognized as an internal or external command,\r\noperable program or batch file.", - "'powershell' is not recognized as an internal or external command,\r\noperable program or batch file.", + "stderr,get_token_method", + product( + ( + "'pwsh' is not recognized as an internal or external command,\r\noperable program or batch file.", + "'powershell' is not recognized as an internal or external command,\r\noperable program or batch file.", + ), + GET_TOKEN_METHODS, ), ) -async def test_powershell_not_installed_cmd(stderr): +async def test_powershell_not_installed_cmd(stderr, get_token_method): """The credential should raise CredentialUnavailableError when PowerShell isn't installed""" mock_exec = get_mock_exec(return_code=1, stderr=stderr) with patch(CREATE_SUBPROCESS_EXEC, mock_exec): with pytest.raises(CredentialUnavailableError, match=POWERSHELL_NOT_INSTALLED): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") -async def test_powershell_not_installed_sh(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_powershell_not_installed_sh(get_token_method): """The credential should raise CredentialUnavailableError when PowerShell isn't installed""" mock_exec = get_mock_exec(return_code=127, stderr="/bin/sh: 0: Can't open pwsh") with patch(CREATE_SUBPROCESS_EXEC, mock_exec): with pytest.raises(CredentialUnavailableError, match=POWERSHELL_NOT_INSTALLED): - await 
AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") -@pytest.mark.parametrize("stderr", (POWERSHELL_INVALID_OPERATION_EXCEPTION, POWERSHELL_NOT_LOGGED_IN_ERROR)) -async def test_not_logged_in(stderr): +@pytest.mark.parametrize( + "stderr,get_token_method", + product((POWERSHELL_INVALID_OPERATION_EXCEPTION, POWERSHELL_NOT_LOGGED_IN_ERROR), GET_TOKEN_METHODS), +) +async def test_not_logged_in(stderr, get_token_method): """The credential should raise CredentialUnavailableError when a user isn't logged in to Azure PowerShell""" mock_exec = get_mock_exec(return_code=1, stderr=stderr) with patch(CREATE_SUBPROCESS_EXEC, mock_exec): with pytest.raises(CredentialUnavailableError, match=RUN_CONNECT_AZ_ACCOUNT): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") -async def test_blocked_by_execution_policy(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_blocked_by_execution_policy(get_token_method): """The credential should raise CredentialUnavailableError when execution policy blocks Get-AzAccessToken""" stderr = r"""#< CLIXML @@ -189,21 +211,23 @@ async def test_blocked_by_execution_policy(): mock_exec = get_mock_exec(return_code=1, stderr=stderr) with patch(CREATE_SUBPROCESS_EXEC, mock_exec): with pytest.raises(CredentialUnavailableError, match=BLOCKED_BY_EXECUTION_POLICY): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") -async def test_timeout(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_timeout(get_token_method): """The credential should kill the subprocess after a timeout""" proc = Mock(communicate=Mock(side_effect=asyncio.TimeoutError), returncode=None) with patch(CREATE_SUBPROCESS_EXEC, Mock(return_value=get_completed_future(proc))): with 
pytest.raises(CredentialUnavailableError): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") assert proc.communicate.call_count == 1 -async def test_unexpected_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_unexpected_error(get_token_method): """The credential should log stderr when Get-AzAccessToken returns an unexpected error""" class MockHandler(logging.Handler): @@ -223,7 +247,7 @@ async def test_unexpected_error(): mock_exec = get_mock_exec(return_code=42, stderr=expected_output) with patch(CREATE_SUBPROCESS_EXEC, mock_exec): with pytest.raises(ClientAuthenticationError): - await AzurePowerShellCredential().get_token("scope") + await getattr(AzurePowerShellCredential(), get_token_method)("scope") for message in mock_handler.messages: if message.levelname == "DEBUG" and expected_output in message.message: @@ -233,17 +257,22 @@ async def test_unexpected_error(): @pytest.mark.skipif(not sys.platform.startswith("win"), reason="tests Windows-specific behavior") -async def test_windows_event_loop(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_windows_event_loop(get_token_method): """The credential should fall back to the sync implementation when not using ProactorEventLoop on Windows""" sync_get_token = Mock() credential = AzurePowerShellCredential() with patch(AzurePowerShellCredential.__module__ + "._SyncCredential") as fallback: - fallback.return_value = Mock(spec_set=["get_token"], get_token=sync_get_token) + fallback.return_value = Mock( + spec_set=["get_token", "get_token_info"], + get_token=sync_get_token, + get_token_info=sync_get_token, + ) with patch(AzurePowerShellCredential.__module__ + ".asyncio.get_event_loop"): # asyncio.get_event_loop now returns Mock, i.e. 
never ProactorEventLoop - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert sync_get_token.call_count == 1 @@ -256,7 +285,8 @@ async def test_windows_event_loop(): "some other message", ), ) -async def test_windows_powershell_fallback(error_message): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_windows_powershell_fallback(error_message, get_token_method): """On Windows, the credential should fall back to powershell.exe when pwsh.exe isn't on the path""" calls = 0 @@ -282,12 +312,13 @@ async def test_windows_powershell_fallback(error_message): credential = AzurePowerShellCredential() with pytest.raises(CredentialUnavailableError, match=AZ_ACCOUNT_NOT_INSTALLED): with patch(CREATE_SUBPROCESS_EXEC, mock_exec): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert calls == 2 -async def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication(get_token_method): first_token = "***" second_tenant = "12345" second_token = first_token * 2 @@ -312,18 +343,22 @@ async def test_multitenant_authentication(): credential = AzurePowerShellCredential() with patch(CREATE_SUBPROCESS_EXEC, fake_exec): - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token -async def test_multitenant_authentication_not_allowed(): 
+@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication_not_allowed(get_token_method): expected_token = "***" async def fake_exec(*args, **_): @@ -344,9 +379,12 @@ async def test_multitenant_authentication_not_allowed(): credential = AzurePowerShellCredential() with patch(CREATE_SUBPROCESS_EXEC, fake_exec): - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = await credential.get_token("scope", tenant_id="12345") + kwargs = {"tenant_id": "12345"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index ba1052ac52c..66709b61f0d 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -35,14 +35,16 @@ from helpers import ( msal_validating_transport, Request, validating_transport, + GET_TOKEN_METHODS, ) -def test_close(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_close(get_token_method): transport = MagicMock() credential = SharedTokenCacheCredential(transport=transport, _cache=TokenCache()) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert not transport.__enter__.called assert not transport.__exit__.called @@ -52,11 +54,12 @@ def test_close(): assert transport.__exit__.call_count == 1 -def test_context_manager(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_context_manager(get_token_method): 
transport = MagicMock() credential = SharedTokenCacheCredential(transport=transport, _cache=TokenCache()) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert not transport.__enter__.called assert not transport.__exit__.called @@ -92,15 +95,17 @@ def test_supported(): assert SharedTokenCacheCredential.supported() -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise when get_token is called with no scopes""" credential = SharedTokenCacheCredential(_cache=TokenCache()) with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) def send(*_, **kwargs): @@ -115,12 +120,13 @@ def test_policies_configurable(): transport=Mock(send=send), ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert policy.on_request.called -def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_user_agent(get_token_method): transport = validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -130,10 +136,11 @@ def test_user_agent(): _cache=populated_cache(get_account_event("test@user", "uid", "utid")), transport=transport ) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_tenant_id(get_token_method): transport = validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], 
responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -145,7 +152,11 @@ def test_tenant_id(): additionally_allowed_tenants=["*"], ) - credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + + getattr(credential, get_token_method)("scope", **kwargs) @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) @@ -168,20 +179,25 @@ def test_authority(authority): MockCredential(_cache=TokenCache(), authority=authority, transport=transport) -def test_empty_cache(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_empty_cache(get_token_method): """the credential should raise CredentialUnavailableError when the cache is empty""" with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - SharedTokenCacheCredential(_cache=TokenCache()).get_token("scope") + getattr(SharedTokenCacheCredential(_cache=TokenCache()), get_token_method)("scope") with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - SharedTokenCacheCredential(_cache=TokenCache(), username="not@cache").get_token("scope") + getattr(SharedTokenCacheCredential(_cache=TokenCache(), username="not@cache"), get_token_method)("scope") with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - SharedTokenCacheCredential(_cache=TokenCache(), tenant_id="not-cached").get_token("scope") + getattr(SharedTokenCacheCredential(_cache=TokenCache(), tenant_id="not-cached"), get_token_method)("scope") with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - SharedTokenCacheCredential(_cache=TokenCache(), tenant_id="not-cached", username="not@cache").get_token("scope") + getattr( + SharedTokenCacheCredential(_cache=TokenCache(), tenant_id="not-cached", username="not@cache"), + get_token_method, + )("scope") -def test_no_matching_account_for_username(): +@pytest.mark.parametrize("get_token_method", 
GET_TOKEN_METHODS) +def test_no_matching_account_for_username(get_token_method): """one cached account, username specified, username doesn't match -> credential should raise""" upn = "spam@eggs" @@ -190,13 +206,14 @@ def test_no_matching_account_for_username(): cache = populated_cache(account) with pytest.raises(CredentialUnavailableError) as ex: - SharedTokenCacheCredential(_cache=cache, username="not" + upn).get_token("scope") + getattr(SharedTokenCacheCredential(_cache=cache, username="not" + upn), get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert "not" + upn in ex.value.message -def test_no_matching_account_for_tenant(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_matching_account_for_tenant(get_token_method): """one cached account, tenant specified, tenant doesn't match -> credential should raise""" upn = "spam@eggs" @@ -205,13 +222,14 @@ def test_no_matching_account_for_tenant(): cache = populated_cache(account) with pytest.raises(CredentialUnavailableError) as ex: - SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant).get_token("scope") + getattr(SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant), get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert "not-" + tenant in ex.value.message -def test_no_matching_account_for_tenant_and_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_matching_account_for_tenant_and_username(get_token_method): """one cached account, tenant and username specified, neither match -> credential should raise""" upn = "spam@eggs" @@ -220,13 +238,16 @@ def test_no_matching_account_for_tenant_and_username(): cache = populated_cache(account) with pytest.raises(CredentialUnavailableError) as ex: - SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant, username="not" + 
upn).get_token("scope") + getattr( + SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant, username="not" + upn), get_token_method + )("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert "not" + upn in ex.value.message and "not-" + tenant in ex.value.message -def test_no_matching_account_for_tenant_or_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_matching_account_for_tenant_or_username(get_token_method): """two cached accounts, username and tenant specified, one account matches each -> credential should raise""" refresh_token_a = "refresh-token-a" @@ -243,18 +264,19 @@ def test_no_matching_account_for_tenant_or_username(): credential = SharedTokenCacheCredential(username=upn_a, tenant_id=tenant_b, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert upn_a in ex.value.message and tenant_b in ex.value.message credential = SharedTokenCacheCredential(username=upn_b, tenant_id=tenant_a, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert upn_b in ex.value.message and tenant_a in ex.value.message -def test_single_account_matching_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_single_account_matching_username(get_token_method): """one cached account, username specified, username matches -> credential should auth that account""" upn = "spam@eggs" @@ -269,11 +291,12 @@ def test_single_account_matching_username(): 
responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport, username=upn) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token -def test_single_account_matching_tenant(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_single_account_matching_tenant(get_token_method): """one cached account, tenant specified, tenant matches -> credential should auth that account""" tenant_id = "tenant-id" @@ -288,11 +311,12 @@ def test_single_account_matching_tenant(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport, tenant_id=tenant_id) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token -def test_single_account_matching_tenant_and_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_single_account_matching_tenant_and_username(get_token_method): """one cached account, tenant and username specified, both match -> credential should auth that account""" upn = "spam@eggs" @@ -308,11 +332,12 @@ def test_single_account_matching_tenant_and_username(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport, tenant_id=tenant_id, username=upn) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token -def test_single_account(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_single_account(get_token_method): """one cached account, no username specified -> credential should auth that account""" refresh_token = "refresh-token" @@ -327,11 +352,12 @@ def 
test_single_account(): ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token -def test_no_refresh_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_refresh_token(get_token_method): """one cached account, account has no refresh token -> credential should raise""" account = get_account_event(uid="uid_a", utid="utid", username="spam@eggs", refresh_token=None) @@ -341,14 +367,15 @@ def test_no_refresh_token(): credential = SharedTokenCacheCredential(_cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") credential = SharedTokenCacheCredential(_cache=cache, transport=transport, username="not@cache") with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_two_accounts_no_username_or_tenant(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_two_accounts_no_username_or_tenant(get_token_method): """two cached accounts, no username or tenant specified -> credential should raise""" upn_a = "a@foo" @@ -363,10 +390,11 @@ def test_two_accounts_no_username_or_tenant(): # two users in the cache, no username specified -> CredentialUnavailableError credential = SharedTokenCacheCredential(_cache=cache, transport=transport) with pytest.raises(ClientAuthenticationError, match=MULTIPLE_ACCOUNTS) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_two_accounts_username_specified(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_two_accounts_username_specified(get_token_method): """two cached accounts, username specified, one account matches -> credential should auth that 
account""" scope = "scope" @@ -383,11 +411,12 @@ def test_two_accounts_username_specified(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(username=upn_a, _cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token -def test_two_accounts_tenant_specified(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_two_accounts_tenant_specified(get_token_method): """two cached accounts, tenant specified, one account matches -> credential should auth that account""" scope = "scope" @@ -405,11 +434,12 @@ def test_two_accounts_tenant_specified(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_id, _cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token -def test_two_accounts_tenant_and_username_specified(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_two_accounts_tenant_and_username_specified(get_token_method): """two cached accounts, tenant and username specified, one account matches both -> credential should auth that account""" scope = "scope" @@ -427,11 +457,12 @@ def test_two_accounts_tenant_and_username_specified(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_id, username=upn_a, _cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_token -def test_same_username_different_tenants(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_same_username_different_tenants(get_token_method): """two cached 
accounts, same username, different tenants""" access_token_a = "access-token-a" @@ -450,7 +481,7 @@ def test_same_username_different_tenants(): transport = Mock(side_effect=Exception()) # (so it shouldn't use the network) credential = SharedTokenCacheCredential(username=upn, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")]) assert upn in ex.value.message @@ -462,7 +493,7 @@ def test_same_username_different_tenants(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_a, _cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token_a transport = validating_transport( @@ -470,11 +501,12 @@ def test_same_username_different_tenants(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_b, _cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token_b -def test_same_tenant_different_usernames(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_same_tenant_different_usernames(get_token_method): """two cached accounts, same tenant, different usernames""" access_token_a = "access-token-a" @@ -493,7 +525,7 @@ def test_same_tenant_different_usernames(): transport = Mock(side_effect=Exception()) # (so it shouldn't use the network) credential = SharedTokenCacheCredential(tenant_id=tenant_id, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - credential.get_token("scope") + getattr(credential, 
get_token_method)("scope") assert ex.value.message.startswith(MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")]) assert tenant_id in ex.value.message @@ -505,7 +537,7 @@ def test_same_tenant_different_usernames(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))], ) credential = SharedTokenCacheCredential(username=upn_b, _cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token_a transport = validating_transport( @@ -513,11 +545,12 @@ def test_same_tenant_different_usernames(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))], ) credential = SharedTokenCacheCredential(username=upn_a, _cache=cache, transport=transport) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token_a -def test_authority_aliases(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authority_aliases(get_token_method): """the credential should use a refresh token valid for any known alias of its authority""" expected_access_token = "access-token" @@ -536,7 +569,7 @@ def test_authority_aliases(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], ) credential = SharedTokenCacheCredential(authority=authority, _cache=cache, transport=transport) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # it should be acceptable for every known alias of this authority @@ -546,11 +579,12 @@ def test_authority_aliases(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], ) credential = SharedTokenCacheCredential(authority=alias, _cache=cache, transport=transport) - token = credential.get_token("scope") + token = getattr(credential, 
get_token_method)("scope") assert token.token == expected_access_token -def test_authority_with_no_known_alias(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authority_with_no_known_alias(get_token_method): """given an appropriate token, an authority with no known aliases should work""" authority = "unknown.authority" @@ -563,11 +597,12 @@ def test_authority_with_no_known_alias(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], ) credential = SharedTokenCacheCredential(authority=authority, _cache=cache, transport=transport) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token -def test_authority_environment_variable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authority_environment_variable(get_token_method): """the credential should accept an authority by environment variable when none is otherwise specified""" authority = "localhost" @@ -581,11 +616,12 @@ def test_authority_environment_variable(): ) with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): credential = SharedTokenCacheCredential(transport=transport, _cache=cache) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token -def test_authentication_record_empty_cache(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authentication_record_empty_cache(get_token_method): record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username") def send(request, **kwargs): @@ -601,10 +637,11 @@ def test_authentication_record_empty_cache(): ) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_authentication_record_no_match(): 
+@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authentication_record_no_match(get_token_method): tenant_id = "tenant-id" client_id = "client-id" authority = "localhost" @@ -632,10 +669,11 @@ def test_authentication_record_no_match(): credential = SharedTokenCacheCredential(authentication_record=record, transport=Mock(send=send), _cache=cache) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_authentication_record(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authentication_record(get_token_method): tenant_id = "tenant-id" client_id = "client-id" authority = "localhost" @@ -658,11 +696,12 @@ def test_authentication_record(): ) credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token -def test_auth_record_multiple_accounts_for_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_auth_record_multiple_accounts_for_username(get_token_method): tenant_id = "tenant-id" client_id = "client-id" authority = "localhost" @@ -695,11 +734,12 @@ def test_auth_record_multiple_accounts_for_username(): ) credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_access_token -def test_writes_to_cache(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_writes_to_cache(get_token_method): """the credential should write tokens it acquires to the cache""" scope = "scope" @@ -731,14 +771,14 @@ def test_writes_to_cache(): ], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport) - token = 
credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_access_token # access token should be in the cache, and another instance should retrieve it credential = SharedTokenCacheCredential( _cache=cache, transport=Mock(send=Mock(side_effect=Exception("the credential should return a cached token"))) ) - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == expected_access_token # and the credential should have updated the cached refresh token @@ -748,14 +788,15 @@ def test_writes_to_cache(): responses=[mock_response(json_payload=build_aad_response(access_token=second_access_token))], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport) - token = credential.get_token("some other " + scope) + token = getattr(credential, get_token_method)("some other " + scope) assert token.token == second_access_token # verify the credential didn't add a new cache entry assert len(list(cache.search(TokenCache.CredentialType.REFRESH_TOKEN))) == 1 -def test_initialization(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_initialization(get_token_method): """the credential should attempt to load the cache when it's needed and no cache has been established.""" with patch("azure.identity._internal.shared_token_cache._load_persistent_cache") as mock_cache_loader: @@ -765,15 +806,16 @@ def test_initialization(): assert mock_cache_loader.call_count == 0 with pytest.raises(CredentialUnavailableError, match="Shared token cache unavailable"): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_cache_loader.call_count == 1 with pytest.raises(CredentialUnavailableError, match="Shared token cache unavailable"): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_cache_loader.call_count == 2 -def test_initialization_with_cache_options(): 
+@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_initialization_with_cache_options(get_token_method): """the credential should use user-supplied persistence options""" with patch("azure.identity._internal.shared_token_cache._load_persistent_cache") as mock_cache_loader: @@ -781,21 +823,25 @@ def test_initialization_with_cache_options(): credential = SharedTokenCacheCredential(cache_persistence_options=options) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert mock_cache_loader.call_count == 1 args, _ = mock_cache_loader.call_args assert args[0] == options assert args[1] is False # is_cae is False. with pytest.raises(CredentialUnavailableError): - credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert mock_cache_loader.call_count == 2 args, _ = mock_cache_loader.call_args assert args[0] == options assert args[1] is True # is_cae is True. 
-def test_authentication_record_authenticating_tenant(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_authentication_record_authenticating_tenant(get_token_method): """when given a record and 'tenant_id', the credential should authenticate in the latter""" expected_tenant_id = "tenant-id" @@ -812,12 +858,13 @@ def test_authentication_record_authenticating_tenant(): authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id, transport=transport ) with pytest.raises(CredentialUnavailableError): - credential.get_token("scope") # this raises because the cache is empty + getattr(credential, get_token_method)("scope") # this raises because the cache is empty assert transport.send.called -def test_client_capabilities(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_client_capabilities(get_token_method): """the credential should configure MSAL for capability CP1 only if enable_cae is passed.""" def send(request, **kwargs): @@ -834,20 +881,24 @@ def test_client_capabilities(): with patch("azure.identity._credentials.silent.PublicClientApplication") as PublicClientApplication: with pytest.raises(ClientAuthenticationError): # (cache is empty) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert PublicClientApplication.call_count == 1 _, kwargs = PublicClientApplication.call_args assert kwargs["client_capabilities"] is None with pytest.raises(ClientAuthenticationError): - credential.get_token("scope", enable_cae=True) + kwargs = {"enable_cae": True} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert PublicClientApplication.call_count == 2 _, kwargs = PublicClientApplication.call_args assert kwargs["client_capabilities"] == ["CP1"] -def test_within_dac_error(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_within_dac_error(get_token_method): def 
send(request, **kwargs): # ensure the `claims` and `tenant_id` keywords from credential's `get_token` method don't make it to transport assert "claims" not in kwargs @@ -862,11 +913,12 @@ def test_within_dac_error(): within_dac.set(True) with patch("azure.identity._credentials.silent.PublicClientApplication") as PublicClientApplication: with pytest.raises(CredentialUnavailableError): # (cache is empty) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") within_dac.set(False) -def test_claims_challenge(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_claims_challenge(get_token_method): """get_token should pass any claims challenge to MSAL token acquisition APIs""" expected_claims = '{"access_token": {"essential": "true"}' @@ -882,14 +934,18 @@ def test_claims_challenge(): transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent"))) credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache()) with patch("azure.identity._credentials.silent.PublicClientApplication", lambda *_, **__: msal_app): - credential.get_token("scope", claims=expected_claims) + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_silent_with_error.call_count == 1 args, kwargs = msal_app.acquire_token_silent_with_error.call_args assert kwargs["claims_challenge"] == expected_claims -def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication(get_token_method): default_tenant = "organizations" first_token = "***" second_tenant = "second-tenant" @@ -918,21 +974,28 @@ def test_multitenant_authentication(): credential = SharedTokenCacheCredential( authority=authority, transport=Mock(send=send), _cache=cache, 
additionally_allowed_tenants=["*"] ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token -def test_multitenant_authentication_auth_record(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_auth_record(get_token_method): default_tenant = "organizations" first_token = "***" second_tenant = "second-tenant" @@ -972,17 +1035,23 @@ def test_multitenant_authentication_auth_record(): _cache=cache, additionally_allowed_tenants=["*"], ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token - token = credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # 
should still default to the first tenant - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == first_token @@ -1017,7 +1086,8 @@ def populated_cache(*accounts): return cache -def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_multitenant_authentication_not_allowed(get_token_method): default_tenant = "organizations" expected_token = "***" @@ -1048,12 +1118,18 @@ def test_multitenant_authentication_not_allowed(): credential = SharedTokenCacheCredential(authority=authority, transport=Mock(send=send), _cache=cache) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token - token = credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = credential.get_token("scope", tenant_id="some tenant") + kwargs = {"tenant_id": "some_tenant"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py index 1982b98ccf7..012e965d627 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py @@ -21,7 +21,7 @@ from azure.identity._internal.user_agent import USER_AGENT from msal import TokenCache import pytest -from helpers import build_aad_response, id_token_claims, 
mock_response, Request +from helpers import build_aad_response, id_token_claims, mock_response, Request, GET_TOKEN_METHODS from helpers_async import async_validating_transport, AsyncMockTransport from test_shared_cache_credential import get_account_event, populated_cache @@ -32,16 +32,18 @@ def test_supported(): @pytest.mark.asyncio -async def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_scopes(get_token_method): """The credential should raise when get_token is called with no scopes""" credential = SharedTokenCacheCredential(_cache=TokenCache()) with pytest.raises(ValueError): - await credential.get_token() + await getattr(credential, get_token_method)() @pytest.mark.asyncio -async def test_close(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_close(get_token_method): async def send(*_, **kwargs): # ensure the `claims` and `tenant_id` keywords from credential's `get_token` method don't make it to transport assert "claims" not in kwargs @@ -54,7 +56,7 @@ async def test_close(): ) # the credential doesn't open a transport session before one is needed, so we send a request - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") await credential.close() @@ -62,7 +64,8 @@ async def test_close(): @pytest.mark.asyncio -async def test_context_manager(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_context_manager(get_token_method): async def send(*_, **kwargs): # ensure the `claims` and `tenant_id` keywords from credential's `get_token` method don't make it to transport assert "claims" not in kwargs @@ -76,14 +79,14 @@ async def test_context_manager(): # async with before initialization: credential should call __aexit__ but not __aenter__ async with credential: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert transport.__aenter__.call_count == 0 assert 
transport.__aexit__.call_count == 1 # async with after initialization: credential should call __aenter__ and __aexit__ async with credential: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert transport.__aenter__.call_count == 1 assert transport.__aexit__.call_count == 2 @@ -105,7 +108,8 @@ async def test_context_manager_no_cache(): @pytest.mark.asyncio -async def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) async def send(*_, **kwargs): @@ -120,13 +124,14 @@ async def test_policies_configurable(): transport=Mock(send=send), ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert policy.on_request.called @pytest.mark.asyncio -async def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_user_agent(get_token_method): transport = async_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -136,11 +141,12 @@ async def test_user_agent(): _cache=populated_cache(get_account_event("test@user", "uid", "utid")), transport=transport ) - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.asyncio -async def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_tenant_id(get_token_method): transport = async_validating_transport( requests=[Request(required_headers={"User-Agent": USER_AGENT})], responses=[mock_response(json_payload=build_aad_response(access_token="**"))], @@ -152,7 +158,10 @@ async def test_tenant_id(): additionally_allowed_tenants=["*"], ) - await credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if 
get_token_method == "get_token_info": + kwargs = {"options": kwargs} + await getattr(credential, get_token_method)("scope", **kwargs) @pytest.mark.parametrize("authority", ("localhost", "https://localhost")) @@ -176,22 +185,26 @@ def test_authority(authority): @pytest.mark.asyncio -async def test_empty_cache(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_empty_cache(get_token_method): """the credential should raise CredentialUnavailableError when the cache is empty""" with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - await SharedTokenCacheCredential(_cache=TokenCache()).get_token("scope") + await getattr(SharedTokenCacheCredential(_cache=TokenCache()), get_token_method)("scope") with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - await SharedTokenCacheCredential(_cache=TokenCache(), username="not@cache").get_token("scope") + await getattr(SharedTokenCacheCredential(_cache=TokenCache(), username="not@cache"), get_token_method)("scope") with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - await SharedTokenCacheCredential(_cache=TokenCache(), tenant_id="not-cached").get_token("scope") + await getattr(SharedTokenCacheCredential(_cache=TokenCache(), tenant_id="not-cached"), get_token_method)( + "scope" + ) with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): credential = SharedTokenCacheCredential(_cache=TokenCache(), tenant_id="not-cached", username="not@cache") - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.asyncio -async def test_no_matching_account_for_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_matching_account_for_username(get_token_method): """one cached account, username specified, username doesn't match -> credential should raise""" upn = "spam@eggs" @@ -200,14 +213,15 @@ async def test_no_matching_account_for_username(): cache = 
populated_cache(account) with pytest.raises(CredentialUnavailableError) as ex: - await SharedTokenCacheCredential(_cache=cache, username="not" + upn).get_token("scope") + await getattr(SharedTokenCacheCredential(_cache=cache, username="not" + upn), get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert "not" + upn in ex.value.message @pytest.mark.asyncio -async def test_no_matching_account_for_tenant(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_matching_account_for_tenant(get_token_method): """one cached account, tenant specified, tenant doesn't match -> credential should raise""" upn = "spam@eggs" @@ -216,14 +230,15 @@ async def test_no_matching_account_for_tenant(): cache = populated_cache(account) with pytest.raises(CredentialUnavailableError) as ex: - await SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant).get_token("scope") + await getattr(SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant), get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert "not-" + tenant in ex.value.message @pytest.mark.asyncio -async def test_no_matching_account_for_tenant_and_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_matching_account_for_tenant_and_username(get_token_method): """one cached account, tenant and username specified, neither match -> credential should raise""" upn = "spam@eggs" @@ -232,16 +247,18 @@ async def test_no_matching_account_for_tenant_and_username(): cache = populated_cache(account) with pytest.raises(CredentialUnavailableError) as ex: - await SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant, username="not" + upn).get_token( - "scope" - ) + await getattr( + SharedTokenCacheCredential(_cache=cache, tenant_id="not-" + tenant, username="not" + upn), + get_token_method, + )("scope") 
assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert "not" + upn in ex.value.message and "not-" + tenant in ex.value.message @pytest.mark.asyncio -async def test_no_matching_account_for_tenant_or_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_matching_account_for_tenant_or_username(get_token_method): """two cached accounts, username and tenant specified, one account matches each -> credential should raise""" refresh_token_a = "refresh-token-a" @@ -258,19 +275,20 @@ async def test_no_matching_account_for_tenant_or_username(): credential = SharedTokenCacheCredential(username=upn_a, tenant_id=tenant_b, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert upn_a in ex.value.message and tenant_b in ex.value.message credential = SharedTokenCacheCredential(username=upn_b, tenant_id=tenant_a, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(NO_MATCHING_ACCOUNTS[: NO_MATCHING_ACCOUNTS.index("{")]) assert upn_b in ex.value.message and tenant_a in ex.value.message @pytest.mark.asyncio -async def test_single_account_matching_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_single_account_matching_username(get_token_method): """one cached account, username specified, username matches -> credential should auth that account""" upn = "spam@eggs" @@ -285,12 +303,13 @@ async def test_single_account_matching_username(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = 
SharedTokenCacheCredential(_cache=cache, transport=transport, username=upn) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.asyncio -async def test_single_account_matching_tenant(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_single_account_matching_tenant(get_token_method): """one cached account, tenant specified, tenant matches -> credential should auth that account""" tenant_id = "tenant-id" @@ -305,12 +324,13 @@ async def test_single_account_matching_tenant(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport, tenant_id=tenant_id) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.asyncio -async def test_single_account_matching_tenant_and_username(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_single_account_matching_tenant_and_username(get_token_method): """one cached account, tenant and username specified, both match -> credential should auth that account""" upn = "spam@eggs" @@ -326,12 +346,13 @@ async def test_single_account_matching_tenant_and_username(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport, tenant_id=tenant_id, username=upn) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.asyncio -async def test_single_account(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_single_account(get_token_method): """one cached account, no username specified -> credential should auth that account""" refresh_token = 
"refresh-token" @@ -346,12 +367,13 @@ async def test_single_account(): ) credential = SharedTokenCacheCredential(_cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.asyncio -async def test_no_refresh_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_refresh_token(get_token_method): """one cached account, account has no refresh token -> credential should raise""" account = get_account_event(uid="uid_a", utid="utid", username="spam@eggs", refresh_token=None) @@ -361,15 +383,16 @@ async def test_no_refresh_token(): credential = SharedTokenCacheCredential(_cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") credential = SharedTokenCacheCredential(_cache=cache, transport=transport, username="not@cache") with pytest.raises(CredentialUnavailableError, match=NO_ACCOUNTS): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.asyncio -async def test_two_accounts_no_username_or_tenant(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_two_accounts_no_username_or_tenant(get_token_method): """two cached accounts, no username or tenant specified -> credential should raise""" upn_a = "a@foo" @@ -384,11 +407,12 @@ async def test_two_accounts_no_username_or_tenant(): # two users in the cache, no username specified -> CredentialUnavailableError credential = SharedTokenCacheCredential(_cache=cache, transport=transport) with pytest.raises(ClientAuthenticationError, match=MULTIPLE_ACCOUNTS) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") @pytest.mark.asyncio -async def test_two_accounts_username_specified(): 
+@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_two_accounts_username_specified(get_token_method): """two cached accounts, username specified, one account matches -> credential should auth that account""" scope = "scope" @@ -405,12 +429,13 @@ async def test_two_accounts_username_specified(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(username=upn_a, _cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.asyncio -async def test_two_accounts_tenant_specified(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_two_accounts_tenant_specified(get_token_method): """two cached accounts, tenant specified, one account matches -> credential should auth that account""" scope = "scope" @@ -428,12 +453,13 @@ async def test_two_accounts_tenant_specified(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_id, _cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.asyncio -async def test_two_accounts_tenant_and_username_specified(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_two_accounts_tenant_and_username_specified(get_token_method): """two cached accounts, tenant and username specified, one account matches both -> credential should auth that account""" scope = "scope" @@ -451,12 +477,13 @@ async def test_two_accounts_tenant_and_username_specified(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_token))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_id, username=upn_a, _cache=cache, 
transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == expected_token @pytest.mark.asyncio -async def test_same_username_different_tenants(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_same_username_different_tenants(get_token_method): """two cached accounts, same username, different tenants""" access_token_a = "access-token-a" @@ -475,7 +502,7 @@ async def test_same_username_different_tenants(): transport = Mock(side_effect=Exception()) # (so it shouldn't use the network) credential = SharedTokenCacheCredential(username=upn, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")]) assert upn in ex.value.message @@ -487,7 +514,7 @@ async def test_same_username_different_tenants(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_a, _cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token_a transport = async_validating_transport( @@ -495,12 +522,13 @@ async def test_same_username_different_tenants(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))], ) credential = SharedTokenCacheCredential(tenant_id=tenant_b, _cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token_b @pytest.mark.asyncio -async def test_same_tenant_different_usernames(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def 
test_same_tenant_different_usernames(get_token_method): """two cached accounts, same tenant, different usernames""" access_token_a = "access-token-a" @@ -519,7 +547,7 @@ async def test_same_tenant_different_usernames(): transport = Mock(side_effect=Exception()) # (so it shouldn't use the network) credential = SharedTokenCacheCredential(tenant_id=tenant_id, _cache=cache, transport=transport) with pytest.raises(CredentialUnavailableError) as ex: - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert ex.value.message.startswith(MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")]) assert tenant_id in ex.value.message @@ -531,7 +559,7 @@ async def test_same_tenant_different_usernames(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))], ) credential = SharedTokenCacheCredential(username=upn_b, _cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token_a transport = async_validating_transport( @@ -539,12 +567,13 @@ async def test_same_tenant_different_usernames(): responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))], ) credential = SharedTokenCacheCredential(username=upn_a, _cache=cache, transport=transport) - token = await credential.get_token(scope) + token = await getattr(credential, get_token_method)(scope) assert token.token == access_token_a @pytest.mark.asyncio -async def test_authority_aliases(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_authority_aliases(get_token_method): """the credential should use a refresh token valid for any known alias of its authority""" expected_access_token = "access-token" @@ -563,7 +592,7 @@ async def test_authority_aliases(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], ) credential = 
SharedTokenCacheCredential(authority=authority, _cache=cache, transport=transport) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token # it should also be acceptable for every known alias of this authority @@ -573,12 +602,13 @@ async def test_authority_aliases(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], ) credential = SharedTokenCacheCredential(authority=alias, _cache=cache, transport=transport) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token @pytest.mark.asyncio -async def test_authority_with_no_known_alias(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_authority_with_no_known_alias(get_token_method): """given an appropriate token, an authority with no known aliases should work""" authority = "unknown.authority" @@ -591,12 +621,13 @@ async def test_authority_with_no_known_alias(): responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], ) credential = SharedTokenCacheCredential(authority=authority, _cache=cache, transport=transport) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token @pytest.mark.asyncio -async def test_authority_environment_variable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_authority_environment_variable(get_token_method): """the credential should accept an authority by environment variable when none is otherwise specified""" authority = "localhost" @@ -610,12 +641,13 @@ async def test_authority_environment_variable(): ) with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): credential = 
SharedTokenCacheCredential(transport=transport, _cache=cache) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_access_token @pytest.mark.asyncio -async def test_initialization(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_initialization(get_token_method): """the credential should attempt to load the cache when it's needed and no cache has been established.""" with patch("azure.identity._persistent_cache._get_persistence") as mock_cache_loader: @@ -625,16 +657,17 @@ async def test_initialization(): assert mock_cache_loader.call_count == 0 with pytest.raises(CredentialUnavailableError, match="Shared token cache unavailable"): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert mock_cache_loader.call_count == 1 with pytest.raises(CredentialUnavailableError, match="Shared token cache unavailable"): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert mock_cache_loader.call_count == 2 @pytest.mark.asyncio -async def test_initialization_with_cache_options(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_initialization_with_cache_options(get_token_method): """the credential should use user-supplied persistence options""" with patch("azure.identity._internal.shared_token_cache._load_persistent_cache") as mock_cache_loader: @@ -642,12 +675,13 @@ async def test_initialization_with_cache_options(): credential = SharedTokenCacheCredential(cache_persistence_options=options) with pytest.raises(CredentialUnavailableError): - await credential.get_token("scope") + await getattr(credential, get_token_method)("scope") assert mock_cache_loader.call_count == 1 @pytest.mark.asyncio -async def test_multitenant_authentication(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def 
test_multitenant_authentication(get_token_method): first_token = "***" second_tenant = "second-tenant" second_token = first_token * 2 @@ -674,22 +708,29 @@ async def test_multitenant_authentication(): credential = SharedTokenCacheCredential( authority=authority, transport=Mock(send=send), _cache=cache, additionally_allowed_tenants=["*"] ) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token - token = await credential.get_token("scope", tenant_id="organizations") + kwargs = {"tenant_id": "organizations"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == first_token - token = await credential.get_token("scope", tenant_id=second_tenant) + kwargs = {"tenant_id": second_tenant} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == second_token # should still default to the first tenant - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == first_token @pytest.mark.asyncio -async def test_multitenant_authentication_not_allowed(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_multitenant_authentication_not_allowed(get_token_method): default_tenant = "organizations" expected_token = "***" @@ -715,12 +756,18 @@ async def test_multitenant_authentication_not_allowed(): credential = SharedTokenCacheCredential(authority=authority, transport=Mock(send=send), _cache=cache) - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == expected_token - token = await credential.get_token("scope", tenant_id=default_tenant) + kwargs = {"tenant_id": default_tenant} + if get_token_method == 
"get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token + kwargs = {"tenant_id": "some_tenant"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} with patch.dict("os.environ", {EnvironmentVariables.AZURE_IDENTITY_DISABLE_MULTITENANTAUTH: "true"}): - token = await credential.get_token("scope", tenant_id="some tenant") + token = await getattr(credential, get_token_method)("scope", **kwargs) assert token.token == expected_token diff --git a/sdk/identity/azure-identity/tests/test_username_password_credential.py b/sdk/identity/azure-identity/tests/test_username_password_credential.py index ec86c117424..3ce95488fc3 100644 --- a/sdk/identity/azure-identity/tests/test_username_password_credential.py +++ b/sdk/identity/azure-identity/tests/test_username_password_credential.py @@ -2,6 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from unittest.mock import Mock, patch + from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity import UsernamePasswordCredential from azure.identity._internal.user_agent import USER_AGENT @@ -15,13 +17,9 @@ from helpers import ( mock_response, Request, validating_transport, + GET_TOKEN_METHODS, ) -try: - from unittest.mock import Mock, patch -except ImportError: # python < 3.3 - from mock import Mock, patch # type: ignore - def test_tenant_id_validation(): """The credential should raise ValueError when given an invalid tenant_id""" @@ -36,15 +34,17 @@ def test_tenant_id_validation(): UsernamePasswordCredential("client-id", "username", "password", tenant_id=tenant) -def test_no_scopes(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_no_scopes(get_token_method): """The credential should raise when get_token is called with no scopes""" credential = UsernamePasswordCredential("client-id", "username", 
"password") with pytest.raises(ValueError): - credential.get_token() + getattr(credential, get_token_method)() -def test_policies_configurable(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_policies_configurable(get_token_method): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) transport = validating_transport( @@ -54,12 +54,13 @@ def test_policies_configurable(): ) credential = UsernamePasswordCredential("client-id", "username", "password", policies=[policy], transport=transport) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") assert policy.on_request.called -def test_user_agent(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_user_agent(get_token_method): transport = validating_transport( requests=[Request()] * 2 + [Request(required_headers={"User-Agent": USER_AGENT})], responses=[get_discovery_response()] * 2 @@ -68,10 +69,11 @@ def test_user_agent(): credential = UsernamePasswordCredential("client-id", "username", "password", transport=transport) - credential.get_token("scope") + getattr(credential, get_token_method)("scope") -def test_tenant_id(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_tenant_id(get_token_method): transport = validating_transport( requests=[Request()] * 2 + [Request(required_headers={"User-Agent": USER_AGENT})], responses=[get_discovery_response()] * 2 @@ -82,10 +84,14 @@ def test_tenant_id(): "client-id", "username", "password", transport=transport, additionally_allowed_tenants=["*"] ) - credential.get_token("scope", tenant_id="tenant_id") + kwargs = {"tenant_id": "tenant_id"} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) -def test_username_password_credential(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_username_password_credential(get_token_method): expected_token = "access-token" 
client_id = "client-id" transport = validating_transport( @@ -110,11 +116,12 @@ def test_username_password_credential(): disable_instance_discovery=True, # kwargs are passed to MSAL; this one prevents a Microsoft Entra verification request ) - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == expected_token -def test_authenticate(): +@pytest.mark.parametrize("get_token_method", ["get_token_info"]) +def test_authenticate(get_token_method): client_id = "client-id" environment = "localhost" issuer = "https://" + environment @@ -158,7 +165,7 @@ def test_authenticate(): assert record.username == username # credential should have a cached access token for the scope passed to authenticate - token = credential.get_token(scope) + token = getattr(credential, get_token_method)(scope) assert token.token == access_token @@ -182,7 +189,8 @@ def test_client_capabilities(): assert kwargs["client_capabilities"] == ["CP1"] -def test_claims_challenge(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_claims_challenge(get_token_method): """get_token should and authenticate pass any claims challenge to MSAL token acquisition APIs""" msal_acquire_token_result = dict( @@ -202,7 +210,10 @@ def test_claims_challenge(): args, kwargs = msal_app.acquire_token_by_username_password.call_args assert kwargs["claims_challenge"] == expected_claims - credential.get_token("scope", claims=expected_claims) + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_by_username_password.call_count == 2 args, kwargs = msal_app.acquire_token_by_username_password.call_args @@ -210,7 +221,10 @@ def test_claims_challenge(): msal_app.get_accounts.return_value = [{"home_account_id": credential._auth_record.home_account_id}] msal_app.acquire_token_silent_with_error.return_value = 
msal_acquire_token_result - credential.get_token("scope", claims=expected_claims) + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + getattr(credential, get_token_method)("scope", **kwargs) assert msal_app.acquire_token_silent_with_error.call_count == 1 args, kwargs = msal_app.acquire_token_silent_with_error.call_args diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential.py b/sdk/identity/azure-identity/tests/test_vscode_credential.py index a2e5b37e592..becfe82af14 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential.py @@ -4,6 +4,8 @@ # ------------------------------------ import sys import time +from unittest import mock +from urllib.parse import urlparse from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError @@ -12,14 +14,9 @@ from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT import pytest -from urllib.parse import urlparse from helpers import build_aad_response, mock_response, Request, validating_transport -try: - from unittest import mock -except ImportError: # python < 3.3 - import mock GET_REFRESH_TOKEN = VisualStudioCodeCredential.__module__ + ".get_refresh_token" GET_USER_SETTINGS = VisualStudioCodeCredential.__module__ + ".get_user_settings" diff --git a/sdk/identity/azure-identity/tests/test_workload_identity_credential.py b/sdk/identity/azure-identity/tests/test_workload_identity_credential.py index a744c7e36b7..1db0874a77c 100644 --- a/sdk/identity/azure-identity/tests/test_workload_identity_credential.py +++ b/sdk/identity/azure-identity/tests/test_workload_identity_credential.py @@ -4,9 +4,10 @@ # ------------------------------------ from unittest.mock import mock_open, MagicMock, patch +import pytest from 
azure.identity import WorkloadIdentityCredential -from helpers import mock_response, build_aad_response +from helpers import mock_response, build_aad_response, GET_TOKEN_METHODS def test_workload_identity_credential_initialize(): @@ -18,7 +19,8 @@ def test_workload_identity_credential_initialize(): ) -def test_workload_identity_credential_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_workload_identity_credential_get_token(get_token_method): tenant_id = "tenant-id" client_id = "client-id" access_token = "foo" @@ -38,7 +40,7 @@ def test_workload_identity_credential_get_token(): open_mock = mock_open(read_data=assertion) with patch("builtins.open", open_mock): - token = credential.get_token("scope") + token = getattr(credential, get_token_method)("scope") assert token.token == access_token open_mock.assert_called_once_with(token_file_path, encoding="utf-8") diff --git a/sdk/identity/azure-identity/tests/test_workload_identity_credential_async.py b/sdk/identity/azure-identity/tests/test_workload_identity_credential_async.py index cbda554886a..fc6f3e8c6cb 100644 --- a/sdk/identity/azure-identity/tests/test_workload_identity_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_workload_identity_credential_async.py @@ -7,7 +7,7 @@ from unittest.mock import mock_open, patch, MagicMock import pytest from azure.identity.aio import WorkloadIdentityCredential -from helpers import mock_response, build_aad_response +from helpers import mock_response, build_aad_response, GET_TOKEN_METHODS def test_workload_identity_credential_initialize(): @@ -20,7 +20,8 @@ def test_workload_identity_credential_initialize(): @pytest.mark.asyncio -async def test_workload_identity_credential_get_token(): +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_workload_identity_credential_get_token(get_token_method): tenant_id = "tenant-id" client_id = "client-id" access_token = "foo" @@ -40,7 +41,7 @@ async def 
test_workload_identity_credential_get_token(): open_mock = mock_open(read_data=assertion) with patch("builtins.open", open_mock): - token = await credential.get_token("scope") + token = await getattr(credential, get_token_method)("scope") assert token.token == access_token open_mock.assert_called_once_with(token_file_path, encoding="utf-8") From 754bb7fda3b38fc452dd1bb8133acc88018eced8 Mon Sep 17 00:00:00 2001 From: Waqas Javed <7674577+w-javed@users.noreply.github.com> Date: Tue, 17 Sep 2024 18:08:10 -0700 Subject: [PATCH 09/17] activate-gen-ai-CI (#37432) * activate-gen-ai-CI * activate-gen-ai-CI * activate-azure-gen-ai-pkg-in-CI * activate-azure-gen-ai-pkg-in-CI --- sdk/ai/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/ai/ci.yml b/sdk/ai/ci.yml index 5780d354da0..6ee45adc4d5 100644 --- a/sdk/ai/ci.yml +++ b/sdk/ai/ci.yml @@ -48,8 +48,8 @@ extends: Artifacts: - name: azure-ai-inference safeName: azureaiinference + - name: azure-ai-generative + safeName: azureaigenerative # These packages are deprecated: - #- name: azure-ai-generative - # safeName: azureaigenerative #- name: azure-ai-resources # safeName: azureairesources From 6db5edb8ce926164eab1f4df15af3ab8e7a8b9c5 Mon Sep 17 00:00:00 2001 From: Azure SDK Bot <53356347+azure-sdk@users.noreply.github.com> Date: Tue, 17 Sep 2024 19:47:35 -0700 Subject: [PATCH 10/17] [AutoRelease] t2-cosmosdb-2024-09-11-21177(can only be merged by SDK owner) (#37286) * code and test * update-testcase * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md --------- Co-authored-by: azure-sdk Co-authored-by: ChenxiJiang333 Co-authored-by: ChenxiJiang333 <119990644+ChenxiJiang333@users.noreply.github.com> --- sdk/cosmos/azure-mgmt-cosmosdb/CHANGELOG.md | 12 + sdk/cosmos/azure-mgmt-cosmosdb/_meta.json | 6 +- .../azure/mgmt/cosmosdb/_configuration.py | 4 +- .../cosmosdb/_cosmos_db_management_client.py | 5 +- .../azure/mgmt/cosmosdb/_serialization.py | 2 + .../azure/mgmt/cosmosdb/_vendor.py | 
16 - .../azure/mgmt/cosmosdb/_version.py | 2 +- .../azure/mgmt/cosmosdb/aio/_configuration.py | 4 +- .../aio/_cosmos_db_management_client.py | 5 +- .../_cassandra_clusters_operations.py | 146 ++-- .../_cassandra_data_centers_operations.py | 78 +- .../_cassandra_resources_operations.py | 238 +++--- .../aio/operations/_collection_operations.py | 10 +- .../_collection_partition_operations.py | 8 +- ..._collection_partition_region_operations.py | 6 +- .../_collection_region_operations.py | 6 +- .../_database_account_region_operations.py | 6 +- .../_database_accounts_operations.py | 183 +++-- .../aio/operations/_database_operations.py | 10 +- .../_gremlin_resources_operations.py | 257 +++--- .../aio/operations/_locations_operations.py | 9 +- .../_mongo_db_resources_operations.py | 351 ++++---- .../_notebook_workspaces_operations.py | 102 ++- .../cosmosdb/aio/operations/_operations.py | 6 +- .../_partition_key_range_id_operations.py | 6 +- ...artition_key_range_id_region_operations.py | 6 +- .../aio/operations/_percentile_operations.py | 6 +- .../_percentile_source_target_operations.py | 6 +- .../_percentile_target_operations.py | 6 +- ...private_endpoint_connections_operations.py | 59 +- .../_private_link_resources_operations.py | 9 +- ...restorable_database_accounts_operations.py | 11 +- ...restorable_gremlin_databases_operations.py | 6 +- .../_restorable_gremlin_graphs_operations.py | 6 +- ...restorable_gremlin_resources_operations.py | 6 +- ...storable_mongodb_collections_operations.py | 6 +- ...restorable_mongodb_databases_operations.py | 6 +- ...restorable_mongodb_resources_operations.py | 6 +- .../_restorable_sql_containers_operations.py | 6 +- .../_restorable_sql_databases_operations.py | 6 +- .../_restorable_sql_resources_operations.py | 6 +- .../_restorable_table_resources_operations.py | 6 +- .../_restorable_tables_operations.py | 6 +- .../aio/operations/_service_operations.py | 57 +- .../operations/_sql_resources_operations.py | 524 +++++++----- 
.../operations/_table_resources_operations.py | 142 ++-- .../azure/mgmt/cosmosdb/models/__init__.py | 6 + .../_cosmos_db_management_client_enums.py | 3 +- .../azure/mgmt/cosmosdb/models/_models_py3.py | 151 +++- .../_cassandra_clusters_operations.py | 166 ++-- .../_cassandra_data_centers_operations.py | 88 +- .../_cassandra_resources_operations.py | 270 ++++--- .../operations/_collection_operations.py | 16 +- .../_collection_partition_operations.py | 12 +- ..._collection_partition_region_operations.py | 8 +- .../_collection_region_operations.py | 8 +- .../_database_account_region_operations.py | 8 +- .../_database_accounts_operations.py | 219 ++--- .../operations/_database_operations.py | 16 +- .../_gremlin_resources_operations.py | 291 ++++--- .../operations/_locations_operations.py | 13 +- .../_mongo_db_resources_operations.py | 401 +++++---- .../_notebook_workspaces_operations.py | 116 ++- .../mgmt/cosmosdb/operations/_operations.py | 8 +- .../_partition_key_range_id_operations.py | 8 +- ...artition_key_range_id_region_operations.py | 8 +- .../operations/_percentile_operations.py | 8 +- .../_percentile_source_target_operations.py | 8 +- .../_percentile_target_operations.py | 8 +- ...private_endpoint_connections_operations.py | 67 +- .../_private_link_resources_operations.py | 13 +- ...restorable_database_accounts_operations.py | 17 +- ...restorable_gremlin_databases_operations.py | 8 +- .../_restorable_gremlin_graphs_operations.py | 8 +- ...restorable_gremlin_resources_operations.py | 8 +- ...storable_mongodb_collections_operations.py | 8 +- ...restorable_mongodb_databases_operations.py | 8 +- ...restorable_mongodb_resources_operations.py | 8 +- .../_restorable_sql_containers_operations.py | 8 +- .../_restorable_sql_databases_operations.py | 8 +- .../_restorable_sql_resources_operations.py | 8 +- .../_restorable_table_resources_operations.py | 8 +- .../_restorable_tables_operations.py | 8 +- .../operations/_service_operations.py | 65 +- 
.../operations/_sql_resources_operations.py | 604 ++++++++------ .../operations/_table_resources_operations.py | 160 ++-- .../azure-mgmt-cosmosdb/dev_requirements.txt | 3 +- ...mos_db_cassandra_keyspace_create_update.py | 4 +- .../cosmos_db_cassandra_keyspace_delete.py | 2 +- .../cosmos_db_cassandra_keyspace_get.py | 2 +- .../cosmos_db_cassandra_keyspace_list.py | 2 +- ...cassandra_keyspace_migrate_to_autoscale.py | 2 +- ...a_keyspace_migrate_to_manual_throughput.py | 2 +- ...os_db_cassandra_keyspace_throughput_get.py | 2 +- ...db_cassandra_keyspace_throughput_update.py | 4 +- ...cosmos_db_cassandra_table_create_update.py | 4 +- .../cosmos_db_cassandra_table_delete.py | 2 +- .../cosmos_db_cassandra_table_get.py | 2 +- .../cosmos_db_cassandra_table_list.py | 2 +- ...db_cassandra_table_migrate_to_autoscale.py | 2 +- ...ndra_table_migrate_to_manual_throughput.py | 2 +- ...osmos_db_cassandra_table_throughput_get.py | 2 +- ...os_db_cassandra_table_throughput_update.py | 4 +- ...os_db_collection_get_metric_definitions.py | 2 +- ...py => cosmos_db_collection_get_metrics.py} | 22 +- .../cosmos_db_collection_get_usages.py | 2 +- ...os_db_collection_partition_get_metrics.py} | 33 +- ...smos_db_collection_partition_get_usages.py | 2 +- ...collection_partition_region_get_metrics.py | 47 ++ .../cosmos_db_data_transfer_service_create.py | 4 +- .../cosmos_db_data_transfer_service_delete.py | 2 +- .../cosmos_db_data_transfer_service_get.py | 2 +- ...s_db_database_account_check_name_exists.py | 2 +- .../cosmos_db_database_account_create_max.py | 4 +- .../cosmos_db_database_account_create_min.py | 4 +- .../cosmos_db_database_account_delete.py | 2 +- ...tabase_account_failover_priority_change.py | 4 +- .../cosmos_db_database_account_get.py | 2 +- ...database_account_get_metric_definitions.py | 2 +- .../cosmos_db_database_account_get_metrics.py | 44 + .../cosmos_db_database_account_get_usages.py | 2 +- .../cosmos_db_database_account_list.py | 2 +- 
...database_account_list_by_resource_group.py | 2 +- ...atabase_account_list_connection_strings.py | 2 +- ...e_account_list_connection_strings_mongo.py | 2 +- .../cosmos_db_database_account_list_keys.py | 2 +- ...db_database_account_list_read_only_keys.py | 4 +- ...smos_db_database_account_offline_region.py | 6 +- ...osmos_db_database_account_online_region.py | 6 +- .../cosmos_db_database_account_patch.py | 4 +- ...smos_db_database_account_regenerate_key.py | 4 +- ..._db_database_account_region_get_metrics.py | 45 ++ ...smos_db_database_get_metric_definitions.py | 2 +- .../cosmos_db_database_get_metrics.py | 45 ++ .../cosmos_db_database_get_usages.py | 2 +- ...mos_db_graph_api_compute_service_create.py | 4 +- ...mos_db_graph_api_compute_service_delete.py | 2 +- ...cosmos_db_graph_api_compute_service_get.py | 2 +- ...osmos_db_gremlin_database_create_update.py | 4 +- .../cosmos_db_gremlin_database_delete.py | 2 +- .../cosmos_db_gremlin_database_get.py | 2 +- .../cosmos_db_gremlin_database_list.py | 2 +- ...b_gremlin_database_migrate_to_autoscale.py | 2 +- ...n_database_migrate_to_manual_throughput.py | 2 +- ...smos_db_gremlin_database_throughput_get.py | 2 +- ...s_db_gremlin_database_throughput_update.py | 4 +- ...mos_db_gremlin_graph_backup_information.py | 4 +- .../cosmos_db_gremlin_graph_create_update.py | 4 +- .../cosmos_db_gremlin_graph_delete.py | 2 +- .../cosmos_db_gremlin_graph_get.py | 2 +- .../cosmos_db_gremlin_graph_list.py | 2 +- ...s_db_gremlin_graph_migrate_to_autoscale.py | 2 +- ...mlin_graph_migrate_to_manual_throughput.py | 2 +- .../cosmos_db_gremlin_graph_throughput_get.py | 2 +- ...smos_db_gremlin_graph_throughput_update.py | 4 +- .../cosmos_db_location_get.py | 2 +- .../cosmos_db_location_list.py | 2 +- ...mos_db_managed_cassandra_cluster_create.py | 4 +- ...db_managed_cassandra_cluster_deallocate.py | 2 +- ...mos_db_managed_cassandra_cluster_delete.py | 2 +- ...cosmos_db_managed_cassandra_cluster_get.py | 2 +- 
...assandra_cluster_list_by_resource_group.py | 2 +- ..._cassandra_cluster_list_by_subscription.py | 2 +- ...smos_db_managed_cassandra_cluster_patch.py | 4 +- ...smos_db_managed_cassandra_cluster_start.py | 2 +- .../cosmos_db_managed_cassandra_command.py | 4 +- ...db_managed_cassandra_data_center_create.py | 4 +- ...db_managed_cassandra_data_center_delete.py | 2 +- ...os_db_managed_cassandra_data_center_get.py | 2 +- ...s_db_managed_cassandra_data_center_list.py | 2 +- ..._db_managed_cassandra_data_center_patch.py | 4 +- .../cosmos_db_managed_cassandra_status.py | 2 +- ...terialized_views_builder_service_create.py | 4 +- ...terialized_views_builder_service_delete.py | 2 +- ..._materialized_views_builder_service_get.py | 2 +- ..._mongo_db_collection_backup_information.py | 4 +- .../cosmos_db_mongo_db_collection_delete.py | 2 +- .../cosmos_db_mongo_db_collection_get.py | 2 +- .../cosmos_db_mongo_db_collection_list.py | 2 +- ...ongo_db_collection_migrate_to_autoscale.py | 2 +- ...collection_migrate_to_manual_throughput.py | 2 +- ...s_db_mongo_db_collection_throughput_get.py | 2 +- ...b_mongo_db_collection_throughput_update.py | 4 +- .../cosmos_db_mongo_db_database_delete.py | 2 +- .../cosmos_db_mongo_db_database_get.py | 2 +- .../cosmos_db_mongo_db_database_list.py | 2 +- ..._mongo_db_database_migrate_to_autoscale.py | 2 +- ...b_database_migrate_to_manual_throughput.py | 2 +- ...mos_db_mongo_db_database_throughput_get.py | 2 +- ..._db_mongo_db_database_throughput_update.py | 4 +- ..._mongo_db_role_definition_create_update.py | 4 +- ...smos_db_mongo_db_role_definition_delete.py | 2 +- .../cosmos_db_mongo_db_role_definition_get.py | 2 +- ...cosmos_db_mongo_db_role_definition_list.py | 2 +- ..._mongo_db_user_definition_create_update.py | 4 +- ...smos_db_mongo_db_user_definition_delete.py | 2 +- .../cosmos_db_mongo_db_user_definition_get.py | 2 +- ...cosmos_db_mongo_db_user_definition_list.py | 2 +- .../cosmos_db_notebook_workspace_delete.py | 7 +- 
.../cosmos_db_notebook_workspace_get.py | 7 +- .../cosmos_db_notebook_workspace_list.py | 2 +- ...notebook_workspace_list_connection_info.py | 7 +- ...otebook_workspace_regenerate_auth_token.py | 7 +- .../cosmos_db_notebook_workspace_start.py | 7 +- .../cosmos_db_operations_list.py | 2 +- .../cosmos_db_percentile_get_metrics.py | 44 + ...db_percentile_source_target_get_metrics.py | 46 ++ ...cosmos_db_percentile_target_get_metrics.py | 45 ++ ...s_db_private_endpoint_connection_delete.py | 2 +- ...smos_db_private_endpoint_connection_get.py | 2 +- ...db_private_endpoint_connection_list_get.py | 2 +- ...s_db_private_endpoint_connection_update.py | 4 +- .../cosmos_db_private_link_resource_get.py | 2 +- ...osmos_db_private_link_resource_list_get.py | 2 +- ...cosmos_db_region_collection_get_metrics.py | 47 ++ ...smos_db_restorable_database_account_get.py | 2 +- ...mos_db_restorable_database_account_list.py | 2 +- ...rable_database_account_no_location_list.py | 2 +- ...mos_db_restorable_gremlin_database_list.py | 2 +- ...cosmos_db_restorable_gremlin_graph_list.py | 2 +- ...mos_db_restorable_gremlin_resource_list.py | 2 +- ...s_db_restorable_mongodb_collection_list.py | 2 +- ...mos_db_restorable_mongodb_database_list.py | 2 +- ...mos_db_restorable_mongodb_resource_list.py | 2 +- ...cosmos_db_restorable_sql_container_list.py | 2 +- .../cosmos_db_restorable_sql_database_list.py | 2 +- .../cosmos_db_restorable_sql_resource_list.py | 2 +- .../cosmos_db_restorable_table_list.py | 2 +- ...osmos_db_restorable_table_resource_list.py | 2 +- ..._restore_database_account_create_update.py | 5 +- .../cosmos_db_services_list.py | 2 +- ...sql_client_encryption_key_create_update.py | 4 +- ...cosmos_db_sql_client_encryption_key_get.py | 2 +- ...smos_db_sql_client_encryption_keys_list.py | 2 +- ...mos_db_sql_container_backup_information.py | 4 +- .../cosmos_db_sql_container_create_update.py | 4 +- .../cosmos_db_sql_container_delete.py | 2 +- .../cosmos_db_sql_container_get.py | 2 +- 
.../cosmos_db_sql_container_list.py | 2 +- ...s_db_sql_container_migrate_to_autoscale.py | 2 +- ..._container_migrate_to_manual_throughput.py | 2 +- .../cosmos_db_sql_container_throughput_get.py | 2 +- ...smos_db_sql_container_throughput_update.py | 4 +- .../cosmos_db_sql_database_create_update.py | 4 +- .../cosmos_db_sql_database_delete.py | 2 +- .../cosmos_db_sql_database_get.py | 2 +- .../cosmos_db_sql_database_list.py | 2 +- ...os_db_sql_database_migrate_to_autoscale.py | 2 +- ...l_database_migrate_to_manual_throughput.py | 2 +- .../cosmos_db_sql_database_throughput_get.py | 2 +- ...osmos_db_sql_database_throughput_update.py | 4 +- ...os_db_sql_role_assignment_create_update.py | 4 +- .../cosmos_db_sql_role_assignment_delete.py | 2 +- .../cosmos_db_sql_role_assignment_get.py | 2 +- .../cosmos_db_sql_role_assignment_list.py | 2 +- ...os_db_sql_role_definition_create_update.py | 4 +- .../cosmos_db_sql_role_definition_delete.py | 2 +- .../cosmos_db_sql_role_definition_get.py | 2 +- .../cosmos_db_sql_role_definition_list.py | 2 +- ...s_db_sql_stored_procedure_create_update.py | 4 +- .../cosmos_db_sql_stored_procedure_delete.py | 2 +- .../cosmos_db_sql_stored_procedure_get.py | 2 +- .../cosmos_db_sql_stored_procedure_list.py | 2 +- .../cosmos_db_sql_trigger_create_update.py | 4 +- .../cosmos_db_sql_trigger_delete.py | 2 +- .../cosmos_db_sql_trigger_get.py | 2 +- .../cosmos_db_sql_trigger_list.py | 2 +- ...sql_user_defined_function_create_update.py | 4 +- ...mos_db_sql_user_defined_function_delete.py | 2 +- ...cosmos_db_sql_user_defined_function_get.py | 2 +- ...osmos_db_sql_user_defined_function_list.py | 2 +- .../cosmos_db_table_backup_information.py | 4 +- .../cosmos_db_table_create_update.py | 4 +- .../cosmos_db_table_delete.py | 2 +- .../generated_samples/cosmos_db_table_get.py | 2 +- .../generated_samples/cosmos_db_table_list.py | 2 +- .../cosmos_db_table_migrate_to_autoscale.py | 2 +- ...s_db_table_migrate_to_manual_throughput.py | 2 +- 
.../cosmos_db_table_throughput_get.py | 2 +- .../cosmos_db_table_throughput_update.py | 4 +- .../cosmos_dbp_key_range_id_get_metrics.py | 47 ++ ...mos_dbp_key_range_id_region_get_metrics.py | 48 ++ ...db_sql_dedicated_gateway_service_create.py | 4 +- ...db_sql_dedicated_gateway_service_delete.py | 2 +- ...os_db_sql_dedicated_gateway_service_get.py | 2 +- .../generated_tests/conftest.py | 35 + ...anagement_cassandra_clusters_operations.py | 203 +++++ ...ent_cassandra_clusters_operations_async.py | 226 ++++++ ...ement_cassandra_data_centers_operations.py | 147 ++++ ...cassandra_data_centers_operations_async.py | 164 ++++ ...nagement_cassandra_resources_operations.py | 298 +++++++ ...nt_cassandra_resources_operations_async.py | 319 ++++++++ ...mos_db_management_collection_operations.py | 62 ++ ..._management_collection_operations_async.py | 63 ++ ...agement_collection_partition_operations.py | 48 ++ ...t_collection_partition_operations_async.py | 49 ++ ..._collection_partition_region_operations.py | 35 + ...ction_partition_region_operations_async.py | 36 + ...management_collection_region_operations.py | 35 + ...ment_collection_region_operations_async.py | 36 + ...ment_database_account_region_operations.py | 33 + ...atabase_account_region_operations_async.py | 34 + ...management_database_accounts_operations.py | 381 +++++++++ ...ment_database_accounts_operations_async.py | 396 +++++++++ ...osmos_db_management_database_operations.py | 59 ++ ...db_management_database_operations_async.py | 60 ++ ...management_gremlin_resources_operations.py | 339 ++++++++ ...ment_gremlin_resources_operations_async.py | 362 +++++++++ ...smos_db_management_locations_operations.py | 40 + ...b_management_locations_operations_async.py | 41 + ...anagement_mongo_db_resources_operations.py | 440 ++++++++++ ...ent_mongo_db_resources_operations_async.py | 471 +++++++++++ ...nagement_notebook_workspaces_operations.py | 110 +++ ...nt_notebook_workspaces_operations_async.py | 119 +++ 
.../test_cosmos_db_management_operations.py | 29 + ...t_cosmos_db_management_operations_async.py | 30 + ...ement_partition_key_range_id_operations.py | 35 + ...partition_key_range_id_operations_async.py | 36 + ...artition_key_range_id_region_operations.py | 36 + ...on_key_range_id_region_operations_async.py | 37 + ...mos_db_management_percentile_operations.py | 32 + ..._management_percentile_operations_async.py | 33 + ...ent_percentile_source_target_operations.py | 34 + ...rcentile_source_target_operations_async.py | 35 + ...management_percentile_target_operations.py | 33 + ...ment_percentile_target_operations_async.py | 34 + ...private_endpoint_connections_operations.py | 79 ++ ...e_endpoint_connections_operations_async.py | 88 ++ ...ement_private_link_resources_operations.py | 44 + ...private_link_resources_operations_async.py | 45 ++ ...restorable_database_accounts_operations.py | 52 ++ ...able_database_accounts_operations_async.py | 53 ++ ...restorable_gremlin_databases_operations.py | 31 + ...able_gremlin_databases_operations_async.py | 32 + ...nt_restorable_gremlin_graphs_operations.py | 31 + ...torable_gremlin_graphs_operations_async.py | 32 + ...restorable_gremlin_resources_operations.py | 31 + ...able_gremlin_resources_operations_async.py | 32 + ...storable_mongodb_collections_operations.py | 31 + ...le_mongodb_collections_operations_async.py | 32 + ...restorable_mongodb_databases_operations.py | 31 + ...able_mongodb_databases_operations_async.py | 32 + ...restorable_mongodb_resources_operations.py | 31 + ...able_mongodb_resources_operations_async.py | 32 + ...nt_restorable_sql_containers_operations.py | 31 + ...torable_sql_containers_operations_async.py | 32 + ...ent_restorable_sql_databases_operations.py | 31 + ...storable_sql_databases_operations_async.py | 32 + ...ent_restorable_sql_resources_operations.py | 31 + ...storable_sql_resources_operations_async.py | 32 + ...t_restorable_table_resources_operations.py | 31 + 
...orable_table_resources_operations_async.py | 32 + ...management_restorable_tables_operations.py | 31 + ...ment_restorable_tables_operations_async.py | 32 + ...cosmos_db_management_service_operations.py | 71 ++ ..._db_management_service_operations_async.py | 76 ++ ..._db_management_sql_resources_operations.py | 717 ++++++++++++++++ ...nagement_sql_resources_operations_async.py | 762 ++++++++++++++++++ ...b_management_table_resources_operations.py | 172 ++++ ...gement_table_resources_operations_async.py | 185 +++++ sdk/cosmos/azure-mgmt-cosmosdb/setup.py | 1 + .../azure-mgmt-cosmosdb/tests/conftest.py | 35 + ...database_accounts_operations_async_test.py | 29 + ...ement_database_accounts_operations_test.py | 28 + ...mos_db_management_operations_async_test.py | 28 + ...st_cosmos_db_management_operations_test.py | 27 + 366 files changed, 11795 insertions(+), 2671 deletions(-) delete mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_vendor.py rename sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/{cosmos_db_mongo_db_database_create_update.py => cosmos_db_collection_get_metrics.py} (71%) rename sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/{cosmos_db_mongo_db_collection_create_update.py => cosmos_db_collection_partition_get_metrics.py} (58%) create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_partition_region_get_metrics.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get_metrics.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_region_get_metrics.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_get_metrics.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_percentile_get_metrics.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_percentile_source_target_get_metrics.py create mode 100644 
sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_percentile_target_get_metrics.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_region_collection_get_metrics.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_dbp_key_range_id_get_metrics.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_dbp_key_range_id_region_get_metrics.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/conftest.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_clusters_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_clusters_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_data_centers_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_data_centers_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_resources_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_resources_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_region_operations.py create mode 100644 
sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_region_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_region_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_region_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_account_region_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_account_region_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_accounts_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_accounts_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_gremlin_resources_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_gremlin_resources_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_locations_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_locations_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_mongo_db_resources_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_mongo_db_resources_operations_async.py create mode 100644 
sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_notebook_workspaces_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_notebook_workspaces_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_region_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_region_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_source_target_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_source_target_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_target_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_target_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_endpoint_connections_operations.py create mode 100644 
sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_endpoint_connections_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_link_resources_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_link_resources_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_database_accounts_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_database_accounts_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_databases_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_databases_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_graphs_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_graphs_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_resources_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_resources_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_collections_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_collections_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_databases_operations.py create mode 100644 
sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_databases_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_resources_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_resources_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_containers_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_containers_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_databases_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_databases_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_resources_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_resources_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_table_resources_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_table_resources_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_tables_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_tables_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_service_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_service_operations_async.py create mode 100644 
sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_sql_resources_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_sql_resources_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_table_resources_operations.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_table_resources_operations_async.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/tests/conftest.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_database_accounts_operations_async_test.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_database_accounts_operations_test.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_operations_async_test.py create mode 100644 sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_operations_test.py diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/CHANGELOG.md b/sdk/cosmos/azure-mgmt-cosmosdb/CHANGELOG.md index 31672178316..bec3853dc26 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/CHANGELOG.md +++ b/sdk/cosmos/azure-mgmt-cosmosdb/CHANGELOG.md @@ -1,5 +1,17 @@ # Release History +## 9.6.0 (2024-09-18) + +### Features Added + + - Model `ResourceRestoreParameters` added property `restore_with_ttl_disabled` + - Model `RestoreParameters` added parameter `restore_with_ttl_disabled` in method `__init__` + - Model `RestoreParametersBase` added property `restore_with_ttl_disabled` + - Enum `ServerVersion` added member `SEVEN0` + - Added model `ErrorAdditionalInfo` + - Added model `ErrorDetail` + - Added model `ErrorResponseAutoGenerated` + ## 9.5.1 (2024-06-19) ### Features Added diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/_meta.json b/sdk/cosmos/azure-mgmt-cosmosdb/_meta.json index 6035385fb15..2dc61f33d87 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/_meta.json +++ 
b/sdk/cosmos/azure-mgmt-cosmosdb/_meta.json @@ -1,11 +1,11 @@ { - "commit": "f1546dc981fa5d164d7ecd13588520457462c22c", + "commit": "3519c80fe510a268f6e59a29ccac8a53fdec15b6", "repository_url": "https://github.com/Azure/azure-rest-api-specs", "autorest": "3.10.2", "use": [ - "@autorest/python@6.13.19", + "@autorest/python@6.19.0", "@autorest/modelerfour@4.27.0" ], - "autorest_command": "autorest specification/cosmos-db/resource-manager/readme.md --generate-sample=True --include-x-ms-examples-original-file=True --python --python-sdks-folder=/home/vsts/work/1/azure-sdk-for-python/sdk --tag=package-2024-05 --use=@autorest/python@6.13.19 --use=@autorest/modelerfour@4.27.0 --version=3.10.2 --version-tolerant=False", + "autorest_command": "autorest specification/cosmos-db/resource-manager/readme.md --generate-sample=True --generate-test=True --include-x-ms-examples-original-file=True --python --python-sdks-folder=/home/vsts/work/1/azure-sdk-for-python/sdk --tag=package-2024-08 --use=@autorest/python@6.19.0 --use=@autorest/modelerfour@4.27.0 --version=3.10.2 --version-tolerant=False", "readme": "specification/cosmos-db/resource-manager/readme.md" } \ No newline at end of file diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_configuration.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_configuration.py index 0cb6a524eee..f91c65b1ff3 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_configuration.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_configuration.py @@ -28,13 +28,13 @@ class CosmosDBManagementClientConfiguration: # pylint: disable=too-many-instanc :type credential: ~azure.core.credentials.TokenCredential :param subscription_id: The ID of the target subscription. Required. :type subscription_id: str - :keyword api_version: Api Version. Default value is "2024-05-15". Note that overriding this + :keyword api_version: Api Version. Default value is "2024-08-15". 
Note that overriding this default value may result in unsupported behavior. :paramtype api_version: str """ def __init__(self, credential: "TokenCredential", subscription_id: str, **kwargs: Any) -> None: - api_version: str = kwargs.pop("api_version", "2024-05-15") + api_version: str = kwargs.pop("api_version", "2024-08-15") if credential is None: raise ValueError("Parameter 'credential' must not be None.") diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_cosmos_db_management_client.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_cosmos_db_management_client.py index 6198471e5ce..eefe7c15344 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_cosmos_db_management_client.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_cosmos_db_management_client.py @@ -8,6 +8,7 @@ from copy import deepcopy from typing import Any, TYPE_CHECKING +from typing_extensions import Self from azure.core.pipeline import policies from azure.core.rest import HttpRequest, HttpResponse @@ -161,7 +162,7 @@ class CosmosDBManagementClient: # pylint: disable=client-accepts-api-version-ke :type subscription_id: str :param base_url: Service URL. Default value is "https://management.azure.com". :type base_url: str - :keyword api_version: Api Version. Default value is "2024-05-15". Note that overriding this + :keyword api_version: Api Version. Default value is "2024-08-15". Note that overriding this default value may result in unsupported behavior. 
:paramtype api_version: str :keyword int polling_interval: Default waiting time between two polls for LRO operations if no @@ -323,7 +324,7 @@ class CosmosDBManagementClient: # pylint: disable=client-accepts-api-version-ke def close(self) -> None: self._client.close() - def __enter__(self) -> "CosmosDBManagementClient": + def __enter__(self) -> Self: self._client.__enter__() return self diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_serialization.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_serialization.py index f0c6180722c..8139854b97b 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_serialization.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_serialization.py @@ -144,6 +144,8 @@ class RawDeserializer: # context otherwise. _LOGGER.critical("Wasn't XML not JSON, failing") raise DeserializationError("XML is invalid") from err + elif content_type.startswith("text/"): + return data_as_str raise DeserializationError("Cannot deserialize content-type: {}".format(content_type)) @classmethod diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_vendor.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_vendor.py deleted file mode 100644 index 0dafe0e287f..00000000000 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_vendor.py +++ /dev/null @@ -1,16 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# Code generated by Microsoft (R) AutoRest Code Generator. -# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
-# -------------------------------------------------------------------------- - -from azure.core.pipeline.transport import HttpRequest - - -def _convert_request(request, files=None): - data = request.content if not files else None - request = HttpRequest(method=request.method, url=request.url, headers=request.headers, data=data) - if files: - request.set_formdata_body(files) - return request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_version.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_version.py index 7c492199f55..1154dc0c527 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_version.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/_version.py @@ -6,4 +6,4 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -VERSION = "9.5.1" +VERSION = "9.6.0" diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/_configuration.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/_configuration.py index 697fb89f2b6..016bb54d71a 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/_configuration.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/_configuration.py @@ -28,13 +28,13 @@ class CosmosDBManagementClientConfiguration: # pylint: disable=too-many-instanc :type credential: ~azure.core.credentials_async.AsyncTokenCredential :param subscription_id: The ID of the target subscription. Required. :type subscription_id: str - :keyword api_version: Api Version. Default value is "2024-05-15". Note that overriding this + :keyword api_version: Api Version. Default value is "2024-08-15". Note that overriding this default value may result in unsupported behavior. 
:paramtype api_version: str """ def __init__(self, credential: "AsyncTokenCredential", subscription_id: str, **kwargs: Any) -> None: - api_version: str = kwargs.pop("api_version", "2024-05-15") + api_version: str = kwargs.pop("api_version", "2024-08-15") if credential is None: raise ValueError("Parameter 'credential' must not be None.") diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/_cosmos_db_management_client.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/_cosmos_db_management_client.py index 8c45a55e89c..87ac3503b32 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/_cosmos_db_management_client.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/_cosmos_db_management_client.py @@ -8,6 +8,7 @@ from copy import deepcopy from typing import Any, Awaitable, TYPE_CHECKING +from typing_extensions import Self from azure.core.pipeline import policies from azure.core.rest import AsyncHttpResponse, HttpRequest @@ -164,7 +165,7 @@ class CosmosDBManagementClient: # pylint: disable=client-accepts-api-version-ke :type subscription_id: str :param base_url: Service URL. Default value is "https://management.azure.com". :type base_url: str - :keyword api_version: Api Version. Default value is "2024-05-15". Note that overriding this + :keyword api_version: Api Version. Default value is "2024-08-15". Note that overriding this default value may result in unsupported behavior. 
:paramtype api_version: str :keyword int polling_interval: Default waiting time between two polls for LRO operations if no @@ -328,7 +329,7 @@ class CosmosDBManagementClient: # pylint: disable=client-accepts-api-version-ke async def close(self) -> None: await self._client.close() - async def __aenter__(self) -> "CosmosDBManagementClient": + async def __aenter__(self) -> Self: await self._client.__aenter__() return self diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_cassandra_clusters_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_cassandra_clusters_operations.py index f50c901bd97..8c695f20a17 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_cassandra_clusters_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_cassandra_clusters_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, AsyncIterable, AsyncIterator, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.async_paging import AsyncItemPaged, AsyncList @@ -18,12 +18,13 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse from azure.core.polling import AsyncLROPoller, AsyncNoPolling, AsyncPollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.utils import case_insensitive_dict @@ -31,7 +32,6 @@ 
from azure.mgmt.core.exceptions import ARMErrorFormat from azure.mgmt.core.polling.async_arm_polling import AsyncARMPolling from ... import models as _models -from ..._vendor import _convert_request from ...operations._cassandra_clusters_operations import ( build_create_update_request, build_deallocate_request, @@ -103,7 +103,6 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -119,7 +118,6 @@ class CassandraClustersOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -185,7 +183,6 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -201,7 +198,6 @@ class CassandraClustersOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -265,7 +261,6 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -279,16 +274,14 @@ class CassandraClustersOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ClusterResource", pipeline_response) + deserialized = self._deserialize("ClusterResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore - async def 
_delete_initial( # pylint: disable=inconsistent-return-statements - self, resource_group_name: str, cluster_name: str, **kwargs: Any - ) -> None: + async def _delete_initial(self, resource_group_name: str, cluster_name: str, **kwargs: Any) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -301,7 +294,7 @@ class CassandraClustersOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_request( resource_group_name=resource_group_name, @@ -311,10 +304,10 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -322,11 +315,19 @@ class CassandraClustersOperations: response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete(self, resource_group_name: str, cluster_name: str, 
**kwargs: Any) -> AsyncLROPoller[None]: @@ -350,7 +351,7 @@ class CassandraClustersOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_initial( # type: ignore + raw_result = await self._delete_initial( resource_group_name=resource_group_name, cluster_name=cluster_name, api_version=api_version, @@ -359,6 +360,7 @@ class CassandraClustersOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -386,7 +388,7 @@ class CassandraClustersOperations: cluster_name: str, body: Union[_models.ClusterResource, IO[bytes]], **kwargs: Any - ) -> _models.ClusterResource: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -400,7 +402,7 @@ class CassandraClustersOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.ClusterResource] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -421,10 +423,10 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -432,14 +434,14 @@ class CassandraClustersOperations: response = pipeline_response.http_response if 
response.status_code not in [200, 201]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - if response.status_code == 200: - deserialized = self._deserialize("ClusterResource", pipeline_response) - - if response.status_code == 201: - deserialized = self._deserialize("ClusterResource", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -551,10 +553,11 @@ class CassandraClustersOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ClusterResource", pipeline_response) + deserialized = self._deserialize("ClusterResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -582,7 +585,7 @@ class CassandraClustersOperations: cluster_name: str, body: Union[_models.ClusterResource, IO[bytes]], **kwargs: Any - ) -> _models.ClusterResource: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -596,7 +599,7 @@ class CassandraClustersOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.ClusterResource] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -617,10 +620,10 @@ class CassandraClustersOperations: 
headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -628,14 +631,14 @@ class CassandraClustersOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - if response.status_code == 200: - deserialized = self._deserialize("ClusterResource", pipeline_response) - - if response.status_code == 202: - deserialized = self._deserialize("ClusterResource", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -742,10 +745,11 @@ class CassandraClustersOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ClusterResource", pipeline_response) + deserialized = self._deserialize("ClusterResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -773,7 +777,7 @@ class CassandraClustersOperations: cluster_name: str, body: Union[_models.CommandPostBody, IO[bytes]], **kwargs: Any - ) -> _models.CommandOutput: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -787,7 +791,7 @@ class 
CassandraClustersOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.CommandOutput] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -808,10 +812,10 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -819,10 +823,14 @@ class CassandraClustersOperations: response = pipeline_response.http_response if response.status_code not in [202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("CommandOutput", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -929,10 +937,11 @@ class CassandraClustersOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("CommandOutput", pipeline_response) + deserialized = self._deserialize("CommandOutput", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -954,9 +963,9 @@ class 
CassandraClustersOperations: self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _deallocate_initial( # pylint: disable=inconsistent-return-statements + async def _deallocate_initial( self, resource_group_name: str, cluster_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -969,7 +978,7 @@ class CassandraClustersOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_deallocate_request( resource_group_name=resource_group_name, @@ -979,10 +988,10 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -990,11 +999,19 @@ class CassandraClustersOperations: response = pipeline_response.http_response if response.status_code not in [202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def 
begin_deallocate( @@ -1022,7 +1039,7 @@ class CassandraClustersOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._deallocate_initial( # type: ignore + raw_result = await self._deallocate_initial( resource_group_name=resource_group_name, cluster_name=cluster_name, api_version=api_version, @@ -1031,6 +1048,7 @@ class CassandraClustersOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1052,9 +1070,7 @@ class CassandraClustersOperations: ) return AsyncLROPoller[None](self._client, raw_result, get_long_running_output, polling_method) # type: ignore - async def _start_initial( # pylint: disable=inconsistent-return-statements - self, resource_group_name: str, cluster_name: str, **kwargs: Any - ) -> None: + async def _start_initial(self, resource_group_name: str, cluster_name: str, **kwargs: Any) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1067,7 +1083,7 @@ class CassandraClustersOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_start_request( resource_group_name=resource_group_name, @@ -1077,10 +1093,10 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await 
self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1088,11 +1104,19 @@ class CassandraClustersOperations: response = pipeline_response.http_response if response.status_code not in [202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_start(self, resource_group_name: str, cluster_name: str, **kwargs: Any) -> AsyncLROPoller[None]: @@ -1118,7 +1142,7 @@ class CassandraClustersOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._start_initial( # type: ignore + raw_result = await self._start_initial( resource_group_name=resource_group_name, cluster_name=cluster_name, api_version=api_version, @@ -1127,6 +1151,7 @@ class CassandraClustersOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1185,7 +1210,6 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1199,7 +1223,7 @@ class CassandraClustersOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, 
error_format=ARMErrorFormat) - deserialized = self._deserialize("CassandraClusterPublicStatus", pipeline_response) + deserialized = self._deserialize("CassandraClusterPublicStatus", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_cassandra_data_centers_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_cassandra_data_centers_operations.py index ae9ad2941b8..093aa15cfd5 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_cassandra_data_centers_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_cassandra_data_centers_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, AsyncIterable, AsyncIterator, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.async_paging import AsyncItemPaged, AsyncList @@ -18,12 +18,13 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse from azure.core.polling import AsyncLROPoller, AsyncNoPolling, AsyncPollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.utils import case_insensitive_dict @@ -31,7 +32,6 @@ from azure.mgmt.core.exceptions import ARMErrorFormat from azure.mgmt.core.polling.async_arm_polling import 
AsyncARMPolling from ... import models as _models -from ..._vendor import _convert_request from ...operations._cassandra_data_centers_operations import ( build_create_update_request, build_delete_request, @@ -107,7 +107,6 @@ class CassandraDataCentersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -123,7 +122,6 @@ class CassandraDataCentersOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -192,7 +190,6 @@ class CassandraDataCentersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -206,16 +203,16 @@ class CassandraDataCentersOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DataCenterResource", pipeline_response) + deserialized = self._deserialize("DataCenterResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore - async def _delete_initial( # pylint: disable=inconsistent-return-statements + async def _delete_initial( self, resource_group_name: str, cluster_name: str, data_center_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -228,7 +225,7 @@ class CassandraDataCentersOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = 
kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_request( resource_group_name=resource_group_name, @@ -239,10 +236,10 @@ class CassandraDataCentersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -250,11 +247,19 @@ class CassandraDataCentersOperations: response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete( @@ -282,7 +287,7 @@ class CassandraDataCentersOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_initial( # type: ignore + raw_result = await self._delete_initial( resource_group_name=resource_group_name, cluster_name=cluster_name, data_center_name=data_center_name, @@ -292,6 +297,7 @@ class CassandraDataCentersOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: 
disable=inconsistent-return-statements @@ -320,7 +326,7 @@ class CassandraDataCentersOperations: data_center_name: str, body: Union[_models.DataCenterResource, IO[bytes]], **kwargs: Any - ) -> _models.DataCenterResource: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -334,7 +340,7 @@ class CassandraDataCentersOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.DataCenterResource] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -356,10 +362,10 @@ class CassandraDataCentersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -367,14 +373,14 @@ class CassandraDataCentersOperations: response = pipeline_response.http_response if response.status_code not in [200, 201]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - if response.status_code == 200: - deserialized = self._deserialize("DataCenterResource", pipeline_response) - - if response.status_code == 201: - deserialized = self._deserialize("DataCenterResource", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return 
cls(pipeline_response, deserialized, {}) # type: ignore @@ -494,10 +500,11 @@ class CassandraDataCentersOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("DataCenterResource", pipeline_response) + deserialized = self._deserialize("DataCenterResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -526,7 +533,7 @@ class CassandraDataCentersOperations: data_center_name: str, body: Union[_models.DataCenterResource, IO[bytes]], **kwargs: Any - ) -> _models.DataCenterResource: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -540,7 +547,7 @@ class CassandraDataCentersOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.DataCenterResource] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -562,10 +569,10 @@ class CassandraDataCentersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -573,14 +580,14 @@ class CassandraDataCentersOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): 
+ pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - if response.status_code == 200: - deserialized = self._deserialize("DataCenterResource", pipeline_response) - - if response.status_code == 202: - deserialized = self._deserialize("DataCenterResource", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -697,10 +704,11 @@ class CassandraDataCentersOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("DataCenterResource", pipeline_response) + deserialized = self._deserialize("DataCenterResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_cassandra_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_cassandra_resources_operations.py index 5975ec1005b..ad400dc181b 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_cassandra_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_cassandra_resources_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, AsyncIterable, AsyncIterator, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.async_paging import AsyncItemPaged, AsyncList @@ -18,12 +18,13 @@ from azure.core.exceptions import ( 
ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse from azure.core.polling import AsyncLROPoller, AsyncNoPolling, AsyncPollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.utils import case_insensitive_dict @@ -31,7 +32,6 @@ from azure.mgmt.core.exceptions import ARMErrorFormat from azure.mgmt.core.polling.async_arm_polling import AsyncARMPolling from ... import models as _models -from ..._vendor import _convert_request from ...operations._cassandra_resources_operations import ( build_create_update_cassandra_keyspace_request, build_create_update_cassandra_table_request, @@ -120,7 +120,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -136,7 +135,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -206,7 +204,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -220,7 +217,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise 
HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("CassandraKeyspaceGetResults", pipeline_response) + deserialized = self._deserialize("CassandraKeyspaceGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -234,7 +231,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods keyspace_name: str, create_update_cassandra_keyspace_parameters: Union[_models.CassandraKeyspaceCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.CassandraKeyspaceGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -248,7 +245,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.CassandraKeyspaceGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -272,10 +269,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -283,20 +280,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except 
(StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("CassandraKeyspaceGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -420,10 +419,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("CassandraKeyspaceGetResults", pipeline_response) + deserialized = self._deserialize("CassandraKeyspaceGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -445,9 +445,9 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_cassandra_keyspace_initial( # pylint: disable=inconsistent-return-statements + async def _delete_cassandra_keyspace_initial( self, resource_group_name: str, account_name: str, keyspace_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -460,7 +460,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _params 
= case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_cassandra_keyspace_request( resource_group_name=resource_group_name, @@ -471,10 +471,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -482,6 +482,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -492,8 +496,12 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_cassandra_keyspace( @@ -521,7 +529,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", 
self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_cassandra_keyspace_initial( # type: ignore + raw_result = await self._delete_cassandra_keyspace_initial( resource_group_name=resource_group_name, account_name=account_name, keyspace_name=keyspace_name, @@ -531,6 +539,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -593,7 +602,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -607,7 +615,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -621,7 +629,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods keyspace_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -635,7 +643,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", 
_params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -657,10 +665,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -668,20 +676,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -805,10 +815,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, 
**kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -832,7 +843,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_cassandra_keyspace_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, keyspace_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -845,7 +856,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_cassandra_keyspace_to_autoscale_request( resource_group_name=resource_group_name, @@ -856,10 +867,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -867,20 +878,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = 
pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -924,10 +937,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -951,7 +965,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_cassandra_keyspace_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, keyspace_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -964,7 +978,7 @@ class 
CassandraResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_cassandra_keyspace_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -975,10 +989,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -986,20 +1000,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, 
response_headers) # type: ignore @@ -1043,10 +1059,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1113,7 +1130,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -1129,7 +1145,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -1201,7 +1216,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1215,7 +1229,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("CassandraTableGetResults", pipeline_response) + deserialized = self._deserialize("CassandraTableGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1230,7 +1244,7 @@ class CassandraResourcesOperations: # pylint: 
disable=too-many-public-methods table_name: str, create_update_cassandra_table_parameters: Union[_models.CassandraTableCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.CassandraTableGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1244,7 +1258,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.CassandraTableGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1269,10 +1283,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1280,20 +1294,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("CassandraTableGetResults", pipeline_response) - if response.status_code == 
202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1427,10 +1443,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("CassandraTableGetResults", pipeline_response) + deserialized = self._deserialize("CassandraTableGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1452,9 +1469,9 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_cassandra_table_initial( # pylint: disable=inconsistent-return-statements + async def _delete_cassandra_table_initial( self, resource_group_name: str, account_name: str, keyspace_name: str, table_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1467,7 +1484,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_cassandra_table_request( resource_group_name=resource_group_name, @@ -1479,10 +1496,10 @@ class 
CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1490,6 +1507,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -1500,8 +1521,12 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_cassandra_table( @@ -1531,7 +1556,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_cassandra_table_initial( # type: ignore + raw_result = await self._delete_cassandra_table_initial( resource_group_name=resource_group_name, account_name=account_name, keyspace_name=keyspace_name, @@ -1542,6 +1567,7 @@ class 
CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1607,7 +1633,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1621,7 +1646,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1636,7 +1661,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods table_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1650,7 +1675,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1673,10 +1698,10 @@ 
class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1684,20 +1709,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1831,10 +1858,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # 
type: ignore return deserialized @@ -1858,7 +1886,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_cassandra_table_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, keyspace_name: str, table_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1871,7 +1899,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_cassandra_table_to_autoscale_request( resource_group_name=resource_group_name, @@ -1883,10 +1911,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1894,20 +1922,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized 
= None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1954,10 +1984,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1981,7 +2012,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_cassandra_table_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, keyspace_name: str, table_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1994,7 +2025,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = 
kwargs.pop("cls", None) _request = build_migrate_cassandra_table_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -2006,10 +2037,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2017,20 +2048,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2077,10 +2110,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", 
pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_operations.py index 21c41cc10e8..1e58e81d93e 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._collection_operations import ( build_list_metric_definitions_request, build_list_metrics_request, @@ -119,7 +117,6 @@ class CollectionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -135,7 +132,6 @@ class CollectionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -221,7 +217,6 @@ class CollectionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -237,7 +232,6 @@ class CollectionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -312,7 +306,6 @@ class CollectionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -328,7 +321,6 @@ class CollectionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_partition_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_partition_operations.py index 7b7abd6331a..d6146c5c07b 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_partition_operations.py +++ 
b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_partition_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... import models as _models -from ..._vendor import _convert_request from ...operations._collection_partition_operations import build_list_metrics_request, build_list_usages_request if sys.version_info >= (3, 9): @@ -115,7 +113,6 @@ class CollectionPartitionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -131,7 +128,6 @@ class CollectionPartitionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -217,7 +213,6 @@ class CollectionPartitionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -233,7 +228,6 @@ class CollectionPartitionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_partition_region_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_partition_region_operations.py 
index dbad34849bd..e74c836fb53 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_partition_region_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_partition_region_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... import models as _models -from ..._vendor import _convert_request from ...operations._collection_partition_region_operations import build_list_metrics_request if sys.version_info >= (3, 9): @@ -119,7 +117,6 @@ class CollectionPartitionRegionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -135,7 +132,6 @@ class CollectionPartitionRegionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_region_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_region_operations.py index ace8fe22009..2047546baa8 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_region_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_collection_region_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import 
PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... import models as _models -from ..._vendor import _convert_request from ...operations._collection_region_operations import build_list_metrics_request if sys.version_info >= (3, 9): @@ -119,7 +117,6 @@ class CollectionRegionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -135,7 +132,6 @@ class CollectionRegionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_database_account_region_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_database_account_region_operations.py index b3d031ce17d..5c4d6917bf8 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_database_account_region_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_database_account_region_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._database_account_region_operations import build_list_metrics_request if sys.version_info >= (3, 9): @@ -105,7 +103,6 @@ class DatabaseAccountRegionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -121,7 +118,6 @@ class DatabaseAccountRegionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_database_accounts_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_database_accounts_operations.py index fb11f01e657..2007764906f 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_database_accounts_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_database_accounts_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, AsyncIterable, AsyncIterator, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.async_paging import AsyncItemPaged, AsyncList @@ -18,12 +18,13 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse from azure.core.polling import AsyncLROPoller, AsyncNoPolling, AsyncPollingMethod -from azure.core.rest import 
HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.utils import case_insensitive_dict @@ -31,7 +32,6 @@ from azure.mgmt.core.exceptions import ARMErrorFormat from azure.mgmt.core.polling.async_arm_polling import AsyncARMPolling from ... import models as _models -from ..._vendor import _convert_request from ...operations._database_accounts_operations import ( build_check_name_exists_request, build_create_or_update_request, @@ -117,7 +117,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -131,7 +130,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response) + deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -144,7 +143,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods account_name: str, update_parameters: Union[_models.DatabaseAccountUpdateParameters, IO[bytes]], **kwargs: Any - ) -> _models.DatabaseAccountGetResults: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -158,7 +157,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", 
_headers.pop("Content-Type", None)) - cls: ClsType[_models.DatabaseAccountGetResults] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -179,10 +178,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -190,10 +189,14 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -304,10 +307,11 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response) + deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -335,7 +339,7 @@ class DatabaseAccountsOperations: # pylint: 
disable=too-many-public-methods account_name: str, create_update_parameters: Union[_models.DatabaseAccountCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> _models.DatabaseAccountGetResults: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -349,7 +353,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.DatabaseAccountGetResults] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -370,10 +374,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -381,10 +385,14 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: 
ignore @@ -501,10 +509,11 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response) + deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -526,9 +535,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_initial( # pylint: disable=inconsistent-return-statements - self, resource_group_name: str, account_name: str, **kwargs: Any - ) -> None: + async def _delete_initial(self, resource_group_name: str, account_name: str, **kwargs: Any) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -541,7 +548,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_request( resource_group_name=resource_group_name, @@ -551,10 +558,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, 
**kwargs ) @@ -562,6 +569,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -572,8 +583,12 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete(self, resource_group_name: str, account_name: str, **kwargs: Any) -> AsyncLROPoller[None]: @@ -597,7 +612,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_initial( # type: ignore + raw_result = await self._delete_initial( resource_group_name=resource_group_name, account_name=account_name, api_version=api_version, @@ -606,6 +621,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -627,13 +643,13 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods ) return 
AsyncLROPoller[None](self._client, raw_result, get_long_running_output, polling_method) # type: ignore - async def _failover_priority_change_initial( # pylint: disable=inconsistent-return-statements + async def _failover_priority_change_initial( self, resource_group_name: str, account_name: str, failover_parameters: Union[_models.FailoverPolicies, IO[bytes]], **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -647,7 +663,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -668,10 +684,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -679,6 +695,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -689,8 +709,12 @@ class DatabaseAccountsOperations: # pylint: 
disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @overload async def begin_failover_priority_change( @@ -787,7 +811,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._failover_priority_change_initial( # type: ignore + raw_result = await self._failover_priority_change_initial( resource_group_name=resource_group_name, account_name=account_name, failover_parameters=failover_parameters, @@ -798,6 +822,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -852,7 +877,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -868,7 +892,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -936,7 +959,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = 
_convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -952,7 +974,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -1018,7 +1039,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1032,7 +1052,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DatabaseAccountListKeysResult", pipeline_response) + deserialized = self._deserialize("DatabaseAccountListKeysResult", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1076,7 +1096,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1090,20 +1109,20 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DatabaseAccountListConnectionStringsResult", pipeline_response) + deserialized = self._deserialize("DatabaseAccountListConnectionStringsResult", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # 
type: ignore - async def _offline_region_initial( # pylint: disable=inconsistent-return-statements + async def _offline_region_initial( self, resource_group_name: str, account_name: str, region_parameter_for_offline: Union[_models.RegionForOnlineOffline, IO[bytes]], **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1117,7 +1136,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1138,10 +1157,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1149,6 +1168,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) @@ -1160,8 +1183,12 @@ class DatabaseAccountsOperations: # 
pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @overload async def begin_offline_region( @@ -1252,7 +1279,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._offline_region_initial( # type: ignore + raw_result = await self._offline_region_initial( resource_group_name=resource_group_name, account_name=account_name, region_parameter_for_offline=region_parameter_for_offline, @@ -1263,6 +1290,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1284,13 +1312,13 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods ) return AsyncLROPoller[None](self._client, raw_result, get_long_running_output, polling_method) # type: ignore - async def _online_region_initial( # pylint: disable=inconsistent-return-statements + async def _online_region_initial( self, resource_group_name: str, account_name: str, region_parameter_for_online: Union[_models.RegionForOnlineOffline, IO[bytes]], **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1304,7 +1332,7 @@ class DatabaseAccountsOperations: # pylint: 
disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1325,10 +1353,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1336,6 +1364,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) @@ -1347,8 +1379,12 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @overload async def begin_online_region( @@ -1439,7 +1475,7 @@ class 
DatabaseAccountsOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._online_region_initial( # type: ignore + raw_result = await self._online_region_initial( resource_group_name=resource_group_name, account_name=account_name, region_parameter_for_online=region_parameter_for_online, @@ -1450,6 +1486,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1508,7 +1545,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1522,7 +1558,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DatabaseAccountListReadOnlyKeysResult", pipeline_response) + deserialized = self._deserialize("DatabaseAccountListReadOnlyKeysResult", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1566,7 +1602,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1580,20 +1615,20 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise 
HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DatabaseAccountListReadOnlyKeysResult", pipeline_response) + deserialized = self._deserialize("DatabaseAccountListReadOnlyKeysResult", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore - async def _regenerate_key_initial( # pylint: disable=inconsistent-return-statements + async def _regenerate_key_initial( self, resource_group_name: str, account_name: str, key_to_regenerate: Union[_models.DatabaseAccountRegenerateKeyParameters, IO[bytes]], **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1607,7 +1642,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1628,10 +1663,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1639,6 +1674,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + 
except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -1649,8 +1688,12 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @overload async def begin_regenerate_key( @@ -1739,7 +1782,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._regenerate_key_initial( # type: ignore + raw_result = await self._regenerate_key_initial( resource_group_name=resource_group_name, account_name=account_name, key_to_regenerate=key_to_regenerate, @@ -1750,6 +1793,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1803,7 +1847,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1866,7 +1909,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) 
else: @@ -1882,7 +1924,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -1956,7 +1997,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -1972,7 +2012,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -2041,7 +2080,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -2057,7 +2095,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_database_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_database_operations.py index d6596da9800..7b1205f4526 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_database_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_database_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( 
map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... import models as _models -from ..._vendor import _convert_request from ...operations._database_operations import ( build_list_metric_definitions_request, build_list_metrics_request, @@ -110,7 +108,6 @@ class DatabaseOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -126,7 +123,6 @@ class DatabaseOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -208,7 +204,6 @@ class DatabaseOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -224,7 +219,6 @@ class DatabaseOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -296,7 +290,6 @@ class DatabaseOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -312,7 +305,6 @@ class DatabaseOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" 
return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_gremlin_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_gremlin_resources_operations.py index a799913a504..8763dd96ee7 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_gremlin_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_gremlin_resources_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, AsyncIterable, AsyncIterator, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.async_paging import AsyncItemPaged, AsyncList @@ -18,12 +18,13 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse from azure.core.polling import AsyncLROPoller, AsyncNoPolling, AsyncPollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.utils import case_insensitive_dict @@ -31,7 +32,6 @@ from azure.mgmt.core.exceptions import ARMErrorFormat from azure.mgmt.core.polling.async_arm_polling import AsyncARMPolling from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._gremlin_resources_operations import ( build_create_update_gremlin_database_request, build_create_update_gremlin_graph_request, @@ -121,7 +121,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -137,7 +136,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -207,7 +205,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -221,7 +218,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("GremlinDatabaseGetResults", pipeline_response) + deserialized = self._deserialize("GremlinDatabaseGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -235,7 +232,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods database_name: str, create_update_gremlin_database_parameters: Union[_models.GremlinDatabaseCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.GremlinDatabaseGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -249,7 +246,7 @@ class 
GremlinResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.GremlinDatabaseGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -273,10 +270,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -284,20 +281,22 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("GremlinDatabaseGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -421,10 +420,11 
@@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("GremlinDatabaseGetResults", pipeline_response) + deserialized = self._deserialize("GremlinDatabaseGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -446,9 +446,9 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_gremlin_database_initial( # pylint: disable=inconsistent-return-statements + async def _delete_gremlin_database_initial( self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -461,7 +461,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_gremlin_database_request( resource_group_name=resource_group_name, @@ -472,10 +472,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -483,6 
+483,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -493,8 +497,12 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_gremlin_database( @@ -522,7 +530,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_gremlin_database_initial( # type: ignore + raw_result = await self._delete_gremlin_database_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -532,6 +540,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -594,7 +603,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = 
self._client.format_url(_request.url) _stream = False @@ -608,7 +616,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -622,7 +630,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods database_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -636,7 +644,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -658,10 +666,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -669,20 +677,22 @@ class 
GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -806,10 +816,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -833,7 +844,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_gremlin_database_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: 
ClientAuthenticationError, 404: ResourceNotFoundError, @@ -846,7 +857,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_gremlin_database_to_autoscale_request( resource_group_name=resource_group_name, @@ -857,10 +868,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -868,20 +879,22 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if 
cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -925,10 +938,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -952,7 +966,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_gremlin_database_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -965,7 +979,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_gremlin_database_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -976,10 +990,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await 
self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -987,20 +1001,22 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1044,10 +1060,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1114,7 +1131,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -1130,7 +1146,6 @@ class GremlinResourcesOperations: # pylint: 
disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -1202,7 +1217,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1216,7 +1230,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("GremlinGraphGetResults", pipeline_response) + deserialized = self._deserialize("GremlinGraphGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1231,7 +1245,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods graph_name: str, create_update_gremlin_graph_parameters: Union[_models.GremlinGraphCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.GremlinGraphGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1245,7 +1259,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.GremlinGraphGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1268,10 +1282,10 @@ class 
GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1279,20 +1293,22 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("GremlinGraphGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1423,10 +1439,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("GremlinGraphGetResults", pipeline_response) + deserialized = self._deserialize("GremlinGraphGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return 
deserialized @@ -1448,9 +1465,9 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_gremlin_graph_initial( # pylint: disable=inconsistent-return-statements + async def _delete_gremlin_graph_initial( self, resource_group_name: str, account_name: str, database_name: str, graph_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1463,7 +1480,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_gremlin_graph_request( resource_group_name=resource_group_name, @@ -1475,10 +1492,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1486,6 +1503,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -1496,8 
+1517,12 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_gremlin_graph( @@ -1527,7 +1552,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_gremlin_graph_initial( # type: ignore + raw_result = await self._delete_gremlin_graph_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -1538,6 +1563,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1603,7 +1629,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1617,7 +1642,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", 
pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1632,7 +1657,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods graph_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1646,7 +1671,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1669,10 +1694,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1680,20 +1705,22 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None 
response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1827,10 +1854,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1854,7 +1882,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_gremlin_graph_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, graph_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1867,7 +1895,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) 
_request = build_migrate_gremlin_graph_to_autoscale_request( resource_group_name=resource_group_name, @@ -1879,10 +1907,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1890,20 +1918,22 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1950,10 +1980,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = 
self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1977,7 +2008,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_gremlin_graph_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, graph_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1990,7 +2021,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_gremlin_graph_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -2002,10 +2033,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2013,20 +2044,22 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass 
map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2073,10 +2106,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2106,7 +2140,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods graph_name: str, location: Union[_models.ContinuousBackupRestoreLocation, IO[bytes]], **kwargs: Any - ) -> Optional[_models.BackupInformation]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2120,7 +2154,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: 
ClsType[Optional[_models.BackupInformation]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2143,10 +2177,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2154,12 +2188,14 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2286,10 +2322,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = self._deserialize("BackupInformation", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized diff --git 
a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_locations_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_locations_operations.py index 1df34b58e95..dfcc88306b6 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_locations_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_locations_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... import models as _models -from ..._vendor import _convert_request from ...operations._locations_operations import build_get_request, build_list_request if sys.version_info >= (3, 9): @@ -89,7 +87,6 @@ class LocationsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -105,7 +102,6 @@ class LocationsOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -166,7 +162,6 @@ class LocationsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -180,7 +175,7 @@ class LocationsOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - 
deserialized = self._deserialize("LocationGetResult", pipeline_response) + deserialized = self._deserialize("LocationGetResult", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_mongo_db_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_mongo_db_resources_operations.py index a1190c5a91a..8bf04def918 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_mongo_db_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_mongo_db_resources_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, AsyncIterable, AsyncIterator, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.async_paging import AsyncItemPaged, AsyncList @@ -18,12 +18,13 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse from azure.core.polling import AsyncLROPoller, AsyncNoPolling, AsyncPollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.utils import case_insensitive_dict @@ -31,7 +32,6 @@ from azure.mgmt.core.exceptions import ARMErrorFormat from azure.mgmt.core.polling.async_arm_polling import AsyncARMPolling from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._mongo_db_resources_operations import ( build_create_update_mongo_db_collection_request, build_create_update_mongo_db_database_request, @@ -129,7 +129,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -145,7 +144,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -215,7 +213,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -229,7 +226,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("MongoDBDatabaseGetResults", pipeline_response) + deserialized = self._deserialize("MongoDBDatabaseGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -243,7 +240,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods database_name: str, create_update_mongo_db_database_parameters: Union[_models.MongoDBDatabaseCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.MongoDBDatabaseGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -257,7 +254,7 
@@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.MongoDBDatabaseGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -281,10 +278,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -292,20 +289,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("MongoDBDatabaseGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ 
-429,10 +428,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("MongoDBDatabaseGetResults", pipeline_response) + deserialized = self._deserialize("MongoDBDatabaseGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -454,9 +454,9 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_mongo_db_database_initial( # pylint: disable=inconsistent-return-statements + async def _delete_mongo_db_database_initial( self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -469,7 +469,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_mongo_db_database_request( resource_group_name=resource_group_name, @@ -480,10 +480,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, 
**kwargs ) @@ -491,6 +491,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -501,8 +505,12 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_mongo_db_database( @@ -530,7 +538,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_mongo_db_database_initial( # type: ignore + raw_result = await self._delete_mongo_db_database_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -540,6 +548,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -602,7 +611,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = 
_convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -616,7 +624,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -630,7 +638,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods database_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -644,7 +652,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -666,10 +674,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -677,20 
+685,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -814,10 +824,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -841,7 +852,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_mongo_db_database_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 
401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -854,7 +865,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_mongo_db_database_to_autoscale_request( resource_group_name=resource_group_name, @@ -865,10 +876,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -876,20 +887,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) 
+ if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -933,10 +946,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -960,7 +974,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_mongo_db_database_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -973,7 +987,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_mongo_db_database_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -984,10 +998,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await 
self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -995,20 +1009,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1052,10 +1068,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1122,7 +1139,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -1138,7 +1154,6 @@ class MongoDBResourcesOperations: # pylint: 
disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -1210,7 +1225,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1224,7 +1238,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("MongoDBCollectionGetResults", pipeline_response) + deserialized = self._deserialize("MongoDBCollectionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1239,7 +1253,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods collection_name: str, create_update_mongo_db_collection_parameters: Union[_models.MongoDBCollectionCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.MongoDBCollectionGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1253,7 +1267,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.MongoDBCollectionGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ 
-1278,10 +1292,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1289,20 +1303,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("MongoDBCollectionGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1436,10 +1452,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("MongoDBCollectionGetResults", pipeline_response) + deserialized = self._deserialize("MongoDBCollectionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, 
deserialized, {}) # type: ignore return deserialized @@ -1461,9 +1478,9 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_mongo_db_collection_initial( # pylint: disable=inconsistent-return-statements + async def _delete_mongo_db_collection_initial( self, resource_group_name: str, account_name: str, database_name: str, collection_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1476,7 +1493,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_mongo_db_collection_request( resource_group_name=resource_group_name, @@ -1488,10 +1505,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1499,6 +1516,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise 
HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -1509,8 +1530,12 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_mongo_db_collection( @@ -1540,7 +1565,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_mongo_db_collection_initial( # type: ignore + raw_result = await self._delete_mongo_db_collection_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -1551,6 +1576,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1616,7 +1642,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1630,7 +1655,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", 
pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1645,7 +1670,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods collection_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1659,7 +1684,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1682,10 +1707,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1693,20 +1718,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise 
HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1840,10 +1867,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1867,7 +1895,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_mongo_db_collection_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, collection_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1880,7 +1908,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: 
ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_mongo_db_collection_to_autoscale_request( resource_group_name=resource_group_name, @@ -1892,10 +1920,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1903,20 +1931,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1963,10 +1993,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def 
get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1990,7 +2021,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_mongo_db_collection_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, collection_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2003,7 +2034,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_mongo_db_collection_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -2015,10 +2046,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2026,20 +2057,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 
202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2086,10 +2119,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2152,7 +2186,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -2166,7 +2199,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("MongoRoleDefinitionGetResults", pipeline_response) 
+ deserialized = self._deserialize("MongoRoleDefinitionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2182,7 +2215,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _models.MongoRoleDefinitionCreateUpdateParameters, IO[bytes] ], **kwargs: Any - ) -> Optional[_models.MongoRoleDefinitionGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2196,7 +2229,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.MongoRoleDefinitionGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2220,10 +2253,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2231,12 +2264,14 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - 
deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("MongoRoleDefinitionGetResults", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2363,10 +2398,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("MongoRoleDefinitionGetResults", pipeline_response) + deserialized = self._deserialize("MongoRoleDefinitionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2388,9 +2424,9 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_mongo_role_definition_initial( # pylint: disable=inconsistent-return-statements + async def _delete_mongo_role_definition_initial( self, mongo_role_definition_id: str, resource_group_name: str, account_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2403,7 +2439,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_mongo_role_definition_request( mongo_role_definition_id=mongo_role_definition_id, @@ -2414,10 +2450,10 @@ class MongoDBResourcesOperations: # pylint: 
disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2425,11 +2461,19 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_mongo_role_definition( @@ -2457,7 +2501,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_mongo_role_definition_initial( # type: ignore + raw_result = await self._delete_mongo_role_definition_initial( mongo_role_definition_id=mongo_role_definition_id, resource_group_name=resource_group_name, account_name=account_name, @@ -2467,6 +2511,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: 
disable=inconsistent-return-statements @@ -2530,7 +2575,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -2546,7 +2590,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -2616,7 +2659,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -2630,7 +2672,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("MongoUserDefinitionGetResults", pipeline_response) + deserialized = self._deserialize("MongoUserDefinitionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2646,7 +2688,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _models.MongoUserDefinitionCreateUpdateParameters, IO[bytes] ], **kwargs: Any - ) -> Optional[_models.MongoUserDefinitionGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2660,7 +2702,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = 
kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.MongoUserDefinitionGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2684,10 +2726,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2695,12 +2737,14 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("MongoUserDefinitionGetResults", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2827,10 +2871,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("MongoUserDefinitionGetResults", pipeline_response) + deserialized = self._deserialize("MongoUserDefinitionGetResults", pipeline_response.http_response) if cls: return 
cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2852,9 +2897,9 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_mongo_user_definition_initial( # pylint: disable=inconsistent-return-statements + async def _delete_mongo_user_definition_initial( self, mongo_user_definition_id: str, resource_group_name: str, account_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2867,7 +2912,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_mongo_user_definition_request( mongo_user_definition_id=mongo_user_definition_id, @@ -2878,10 +2923,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2889,11 +2934,19 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, 
error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_mongo_user_definition( @@ -2921,7 +2974,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_mongo_user_definition_initial( # type: ignore + raw_result = await self._delete_mongo_user_definition_initial( mongo_user_definition_id=mongo_user_definition_id, resource_group_name=resource_group_name, account_name=account_name, @@ -2931,6 +2984,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -2994,7 +3048,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -3010,7 +3063,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -3047,7 +3099,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods collection_name: str, location: 
Union[_models.ContinuousBackupRestoreLocation, IO[bytes]], **kwargs: Any - ) -> Optional[_models.BackupInformation]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -3061,7 +3113,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.BackupInformation]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -3084,10 +3136,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -3095,12 +3147,14 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -3227,10 
+3281,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = self._deserialize("BackupInformation", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_notebook_workspaces_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_notebook_workspaces_operations.py index e83410263e3..bbd7e6ab073 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_notebook_workspaces_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_notebook_workspaces_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, AsyncIterable, AsyncIterator, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.async_paging import AsyncItemPaged, AsyncList @@ -18,12 +18,13 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse from azure.core.polling import AsyncLROPoller, AsyncNoPolling, AsyncPollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import 
distributed_trace_async from azure.core.utils import case_insensitive_dict @@ -31,7 +32,6 @@ from azure.mgmt.core.exceptions import ARMErrorFormat from azure.mgmt.core.polling.async_arm_polling import AsyncARMPolling from ... import models as _models -from ..._vendor import _convert_request from ...operations._notebook_workspaces_operations import ( build_create_or_update_request, build_delete_request, @@ -109,7 +109,6 @@ class NotebookWorkspacesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -125,7 +124,6 @@ class NotebookWorkspacesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -200,7 +198,6 @@ class NotebookWorkspacesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -215,7 +212,7 @@ class NotebookWorkspacesOperations: error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) - deserialized = self._deserialize("NotebookWorkspace", pipeline_response) + deserialized = self._deserialize("NotebookWorkspace", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -229,7 +226,7 @@ class NotebookWorkspacesOperations: notebook_workspace_name: Union[str, _models.NotebookWorkspaceName], notebook_create_update_parameters: Union[_models.NotebookWorkspaceCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> _models.NotebookWorkspace: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ 
-243,7 +240,7 @@ class NotebookWorkspacesOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.NotebookWorkspace] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -265,10 +262,10 @@ class NotebookWorkspacesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -276,11 +273,15 @@ class NotebookWorkspacesOperations: response = pipeline_response.http_response if response.status_code not in [200]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) - deserialized = self._deserialize("NotebookWorkspace", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -405,10 +406,11 @@ class NotebookWorkspacesOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("NotebookWorkspace", pipeline_response) + deserialized = self._deserialize("NotebookWorkspace", 
pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -430,13 +432,13 @@ class NotebookWorkspacesOperations: self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_initial( # pylint: disable=inconsistent-return-statements + async def _delete_initial( self, resource_group_name: str, account_name: str, notebook_workspace_name: Union[str, _models.NotebookWorkspaceName], **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -449,7 +451,7 @@ class NotebookWorkspacesOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_request( resource_group_name=resource_group_name, @@ -460,10 +462,10 @@ class NotebookWorkspacesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -471,12 +473,20 @@ class NotebookWorkspacesOperations: response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, 
error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete( @@ -509,7 +519,7 @@ class NotebookWorkspacesOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_initial( # type: ignore + raw_result = await self._delete_initial( resource_group_name=resource_group_name, account_name=account_name, notebook_workspace_name=notebook_workspace_name, @@ -519,6 +529,7 @@ class NotebookWorkspacesOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -585,7 +596,6 @@ class NotebookWorkspacesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -600,20 +610,20 @@ class NotebookWorkspacesOperations: error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) - deserialized = self._deserialize("NotebookWorkspaceConnectionInfoResult", pipeline_response) + deserialized = self._deserialize("NotebookWorkspaceConnectionInfoResult", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore - async def _regenerate_auth_token_initial( # pylint: disable=inconsistent-return-statements + async def _regenerate_auth_token_initial( self, resource_group_name: str, account_name: str, notebook_workspace_name: 
Union[str, _models.NotebookWorkspaceName], **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -626,7 +636,7 @@ class NotebookWorkspacesOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_regenerate_auth_token_request( resource_group_name=resource_group_name, @@ -637,10 +647,10 @@ class NotebookWorkspacesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -648,12 +658,20 @@ class NotebookWorkspacesOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_regenerate_auth_token( @@ -686,7 +704,7 @@ class NotebookWorkspacesOperations: lro_delay = 
kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._regenerate_auth_token_initial( # type: ignore + raw_result = await self._regenerate_auth_token_initial( resource_group_name=resource_group_name, account_name=account_name, notebook_workspace_name=notebook_workspace_name, @@ -696,6 +714,7 @@ class NotebookWorkspacesOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -717,13 +736,13 @@ class NotebookWorkspacesOperations: ) return AsyncLROPoller[None](self._client, raw_result, get_long_running_output, polling_method) # type: ignore - async def _start_initial( # pylint: disable=inconsistent-return-statements + async def _start_initial( self, resource_group_name: str, account_name: str, notebook_workspace_name: Union[str, _models.NotebookWorkspaceName], **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -736,7 +755,7 @@ class NotebookWorkspacesOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_start_request( resource_group_name=resource_group_name, @@ -747,10 +766,10 @@ class NotebookWorkspacesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: 
disable=protected-access _request, stream=_stream, **kwargs ) @@ -758,12 +777,20 @@ class NotebookWorkspacesOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_start( @@ -796,7 +823,7 @@ class NotebookWorkspacesOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._start_initial( # type: ignore + raw_result = await self._start_initial( resource_group_name=resource_group_name, account_name=account_name, notebook_workspace_name=notebook_workspace_name, @@ -806,6 +833,7 @@ class NotebookWorkspacesOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_operations.py index 6270309b2b1..374941add85 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_operations.py 
@@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... import models as _models -from ..._vendor import _convert_request from ...operations._operations import build_list_request if sys.version_info >= (3, 9): @@ -87,7 +85,6 @@ class Operations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -103,7 +100,6 @@ class Operations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_partition_key_range_id_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_partition_key_range_id_operations.py index 452f3560970..e5719f865a4 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_partition_key_range_id_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_partition_key_range_id_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from 
... import models as _models -from ..._vendor import _convert_request from ...operations._partition_key_range_id_operations import build_list_metrics_request if sys.version_info >= (3, 9): @@ -118,7 +116,6 @@ class PartitionKeyRangeIdOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -134,7 +131,6 @@ class PartitionKeyRangeIdOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_partition_key_range_id_region_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_partition_key_range_id_region_operations.py index 5413cd15ad2..81cab31b89d 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_partition_key_range_id_region_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_partition_key_range_id_region_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._partition_key_range_id_region_operations import build_list_metrics_request if sys.version_info >= (3, 9): @@ -123,7 +121,6 @@ class PartitionKeyRangeIdRegionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -139,7 +136,6 @@ class PartitionKeyRangeIdRegionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_percentile_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_percentile_operations.py index 73cdcdb5061..9d20844c217 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_percentile_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_percentile_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._percentile_operations import build_list_metrics_request if sys.version_info >= (3, 9): @@ -103,7 +101,6 @@ class PercentileOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -119,7 +116,6 @@ class PercentileOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_percentile_source_target_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_percentile_source_target_operations.py index 032ba19b5f1..7703bb57011 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_percentile_source_target_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_percentile_source_target_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._percentile_source_target_operations import build_list_metrics_request if sys.version_info >= (3, 9): @@ -117,7 +115,6 @@ class PercentileSourceTargetOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -133,7 +130,6 @@ class PercentileSourceTargetOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_percentile_target_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_percentile_target_operations.py index 3d57f4b07e4..a3b1ae14db5 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_percentile_target_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_percentile_target_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._percentile_target_operations import build_list_metrics_request if sys.version_info >= (3, 9): @@ -107,7 +105,6 @@ class PercentileTargetOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -123,7 +120,6 @@ class PercentileTargetOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_private_endpoint_connections_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_private_endpoint_connections_operations.py index 74b12f3e92f..972d2bd2c7d 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_private_endpoint_connections_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_private_endpoint_connections_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, AsyncIterable, AsyncIterator, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.async_paging import AsyncItemPaged, AsyncList @@ -18,12 +18,13 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse from azure.core.polling import AsyncLROPoller, AsyncNoPolling, AsyncPollingMethod -from 
azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.utils import case_insensitive_dict @@ -31,7 +32,6 @@ from azure.mgmt.core.exceptions import ARMErrorFormat from azure.mgmt.core.polling.async_arm_polling import AsyncARMPolling from ... import models as _models -from ..._vendor import _convert_request from ...operations._private_endpoint_connections_operations import ( build_create_or_update_request, build_delete_request, @@ -108,7 +108,6 @@ class PrivateEndpointConnectionsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -124,7 +123,6 @@ class PrivateEndpointConnectionsOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -193,7 +191,6 @@ class PrivateEndpointConnectionsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -207,7 +204,7 @@ class PrivateEndpointConnectionsOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("PrivateEndpointConnection", pipeline_response) + deserialized = self._deserialize("PrivateEndpointConnection", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -221,7 +218,7 @@ class PrivateEndpointConnectionsOperations: private_endpoint_connection_name: str, parameters: Union[_models.PrivateEndpointConnection, IO[bytes]], **kwargs: Any - ) -> 
Optional[_models.PrivateEndpointConnection]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -235,7 +232,7 @@ class PrivateEndpointConnectionsOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.PrivateEndpointConnection]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -257,10 +254,10 @@ class PrivateEndpointConnectionsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -268,13 +265,15 @@ class PrivateEndpointConnectionsOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) - error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) + error = self._deserialize.failsafe_deserialize(_models.ErrorResponseAutoGenerated, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("PrivateEndpointConnection", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return 
cls(pipeline_response, deserialized, {}) # type: ignore @@ -393,10 +392,11 @@ class PrivateEndpointConnectionsOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("PrivateEndpointConnection", pipeline_response) + deserialized = self._deserialize("PrivateEndpointConnection", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -418,9 +418,9 @@ class PrivateEndpointConnectionsOperations: self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_initial( # pylint: disable=inconsistent-return-statements + async def _delete_initial( self, resource_group_name: str, account_name: str, private_endpoint_connection_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -433,7 +433,7 @@ class PrivateEndpointConnectionsOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_request( resource_group_name=resource_group_name, @@ -444,10 +444,10 @@ class PrivateEndpointConnectionsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -455,12 +455,20 @@ class PrivateEndpointConnectionsOperations: response = 
pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) - error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) + error = self._deserialize.failsafe_deserialize(_models.ErrorResponseAutoGenerated, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete( @@ -488,7 +496,7 @@ class PrivateEndpointConnectionsOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_initial( # type: ignore + raw_result = await self._delete_initial( resource_group_name=resource_group_name, account_name=account_name, private_endpoint_connection_name=private_endpoint_connection_name, @@ -498,6 +506,7 @@ class PrivateEndpointConnectionsOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_private_link_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_private_link_resources_operations.py index 0bf50d2c22b..e89692eb682 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_private_link_resources_operations.py +++ 
b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_private_link_resources_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... import models as _models -from ..._vendor import _convert_request from ...operations._private_link_resources_operations import build_get_request, build_list_by_database_account_request if sys.version_info >= (3, 9): @@ -99,7 +97,6 @@ class PrivateLinkResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -115,7 +112,6 @@ class PrivateLinkResourcesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -184,7 +180,6 @@ class PrivateLinkResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -198,7 +193,7 @@ class PrivateLinkResourcesOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("PrivateLinkResource", pipeline_response) + deserialized = self._deserialize("PrivateLinkResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: 
ignore diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_database_accounts_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_database_accounts_operations.py index 9dd7a454acd..74d32712b83 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_database_accounts_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_database_accounts_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._restorable_database_accounts_operations import ( build_get_by_location_request, build_list_by_location_request, @@ -103,7 +101,6 @@ class RestorableDatabaseAccountsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -119,7 +116,6 @@ class RestorableDatabaseAccountsOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -182,7 +178,6 @@ class RestorableDatabaseAccountsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -198,7 +193,6 @@ class RestorableDatabaseAccountsOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -265,7 +259,6 @@ class RestorableDatabaseAccountsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -279,7 +272,7 @@ class RestorableDatabaseAccountsOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("RestorableDatabaseAccountGetResult", pipeline_response) + deserialized = self._deserialize("RestorableDatabaseAccountGetResult", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore diff --git 
a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_gremlin_databases_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_gremlin_databases_operations.py index 9e2befca1c7..e20776bddcb 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_gremlin_databases_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_gremlin_databases_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... import models as _models -from ..._vendor import _convert_request from ...operations._restorable_gremlin_databases_operations import build_list_request if sys.version_info >= (3, 9): @@ -102,7 +100,6 @@ class RestorableGremlinDatabasesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -118,7 +115,6 @@ class RestorableGremlinDatabasesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_gremlin_graphs_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_gremlin_graphs_operations.py index cc2db438123..9727d2e9c26 100644 --- 
a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_gremlin_graphs_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_gremlin_graphs_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... import models as _models -from ..._vendor import _convert_request from ...operations._restorable_gremlin_graphs_operations import build_list_request if sys.version_info >= (3, 9): @@ -117,7 +115,6 @@ class RestorableGremlinGraphsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -133,7 +130,6 @@ class RestorableGremlinGraphsOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_gremlin_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_gremlin_resources_operations.py index 477549082bb..bed6e396aa7 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_gremlin_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_gremlin_resources_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from 
azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... import models as _models -from ..._vendor import _convert_request from ...operations._restorable_gremlin_resources_operations import build_list_request if sys.version_info >= (3, 9): @@ -115,7 +113,6 @@ class RestorableGremlinResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -131,7 +128,6 @@ class RestorableGremlinResourcesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_mongodb_collections_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_mongodb_collections_operations.py index 61e023c2971..53fbdd05a92 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_mongodb_collections_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_mongodb_collections_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._restorable_mongodb_collections_operations import build_list_request if sys.version_info >= (3, 9): @@ -117,7 +115,6 @@ class RestorableMongodbCollectionsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -133,7 +130,6 @@ class RestorableMongodbCollectionsOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_mongodb_databases_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_mongodb_databases_operations.py index 4cc2ea3dfc4..e8696dcf1c0 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_mongodb_databases_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_mongodb_databases_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._restorable_mongodb_databases_operations import build_list_request if sys.version_info >= (3, 9): @@ -102,7 +100,6 @@ class RestorableMongodbDatabasesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -118,7 +115,6 @@ class RestorableMongodbDatabasesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_mongodb_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_mongodb_resources_operations.py index cf2bb998cd0..3bd82fca524 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_mongodb_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_mongodb_resources_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._restorable_mongodb_resources_operations import build_list_request if sys.version_info >= (3, 9): @@ -115,7 +113,6 @@ class RestorableMongodbResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -131,7 +128,6 @@ class RestorableMongodbResourcesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_sql_containers_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_sql_containers_operations.py index 590cdfbe33b..38f893922e0 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_sql_containers_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_sql_containers_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._restorable_sql_containers_operations import build_list_request if sys.version_info >= (3, 9): @@ -116,7 +114,6 @@ class RestorableSqlContainersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -132,7 +129,6 @@ class RestorableSqlContainersOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_sql_databases_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_sql_databases_operations.py index 3fbbda93533..bfb7c3f659c 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_sql_databases_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_sql_databases_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._restorable_sql_databases_operations import build_list_request if sys.version_info >= (3, 9): @@ -102,7 +100,6 @@ class RestorableSqlDatabasesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -118,7 +115,6 @@ class RestorableSqlDatabasesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_sql_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_sql_resources_operations.py index e4d01217a44..18589426021 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_sql_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_sql_resources_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._restorable_sql_resources_operations import build_list_request if sys.version_info >= (3, 9): @@ -115,7 +113,6 @@ class RestorableSqlResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -131,7 +128,6 @@ class RestorableSqlResourcesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_table_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_table_resources_operations.py index 00db572b892..45f0907202c 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_table_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_table_resources_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._restorable_table_resources_operations import build_list_request if sys.version_info >= (3, 9): @@ -114,7 +112,6 @@ class RestorableTableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -130,7 +127,6 @@ class RestorableTableResourcesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_tables_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_tables_operations.py index db4b59ceeb3..868a6a90b92 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_tables_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_restorable_tables_operations.py @@ -20,14 +20,12 @@ from azure.core.exceptions import ( map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._restorable_tables_operations import build_list_request if sys.version_info >= (3, 9): @@ -112,7 +110,6 @@ class RestorableTablesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -128,7 +125,6 @@ class RestorableTablesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_service_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_service_operations.py index ec32140a6c3..b48c4472638 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_service_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_service_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, AsyncIterable, AsyncIterator, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.async_paging import AsyncItemPaged, AsyncList @@ -18,12 +18,13 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse from azure.core.polling import AsyncLROPoller, AsyncNoPolling, AsyncPollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest 
from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.utils import case_insensitive_dict @@ -31,7 +32,6 @@ from azure.mgmt.core.exceptions import ARMErrorFormat from azure.mgmt.core.polling.async_arm_polling import AsyncARMPolling from ... import models as _models -from ..._vendor import _convert_request from ...operations._service_operations import ( build_create_request, build_delete_request, @@ -106,7 +106,6 @@ class ServiceOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -122,7 +121,6 @@ class ServiceOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -158,7 +156,7 @@ class ServiceOperations: service_name: str, create_update_parameters: Union[_models.ServiceResourceCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ServiceResource]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -172,7 +170,7 @@ class ServiceOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ServiceResource]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -194,10 +192,10 @@ class ServiceOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", 
True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -205,20 +203,22 @@ class ServiceOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ServiceResource", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -336,10 +336,11 @@ class ServiceOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ServiceResource", pipeline_response) + deserialized = self._deserialize("ServiceResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -401,7 +402,6 @@ class ServiceOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -415,16 +415,16 @@ class ServiceOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, 
error_format=ARMErrorFormat) - deserialized = self._deserialize("ServiceResource", pipeline_response) + deserialized = self._deserialize("ServiceResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore - async def _delete_initial( # pylint: disable=inconsistent-return-statements + async def _delete_initial( self, resource_group_name: str, account_name: str, service_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -437,7 +437,7 @@ class ServiceOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_request( resource_group_name=resource_group_name, @@ -448,10 +448,10 @@ class ServiceOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -459,6 +459,10 @@ class ServiceOperations: response = pipeline_response.http_response if response.status_code not in [200, 202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -469,8 +473,12 @@ class ServiceOperations: ) response_headers["location"] = self._deserialize("str", 
response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete( @@ -498,7 +506,7 @@ class ServiceOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_initial( # type: ignore + raw_result = await self._delete_initial( resource_group_name=resource_group_name, account_name=account_name, service_name=service_name, @@ -508,6 +516,7 @@ class ServiceOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_sql_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_sql_resources_operations.py index d2e3d697932..67cafa788b0 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_sql_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_sql_resources_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, AsyncIterable, AsyncIterator, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.async_paging import AsyncItemPaged, AsyncList @@ -18,12 +18,13 @@ from azure.core.exceptions import ( ResourceExistsError, 
ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse from azure.core.polling import AsyncLROPoller, AsyncNoPolling, AsyncPollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.utils import case_insensitive_dict @@ -31,7 +32,6 @@ from azure.mgmt.core.exceptions import ARMErrorFormat from azure.mgmt.core.polling.async_arm_polling import AsyncARMPolling from ... import models as _models -from ..._vendor import _convert_request from ...operations._sql_resources_operations import ( build_create_update_client_encryption_key_request, build_create_update_sql_container_request, @@ -144,7 +144,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -160,7 +159,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -230,7 +228,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -244,7 +241,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) 
- deserialized = self._deserialize("SqlDatabaseGetResults", pipeline_response) + deserialized = self._deserialize("SqlDatabaseGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -258,7 +255,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods database_name: str, create_update_sql_database_parameters: Union[_models.SqlDatabaseCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.SqlDatabaseGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -272,7 +269,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.SqlDatabaseGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -294,10 +291,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -305,20 +302,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, 
error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("SqlDatabaseGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -438,10 +437,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("SqlDatabaseGetResults", pipeline_response) + deserialized = self._deserialize("SqlDatabaseGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -463,9 +463,9 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_sql_database_initial( # pylint: disable=inconsistent-return-statements + async def _delete_sql_database_initial( self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -478,7 +478,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - 
cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_sql_database_request( resource_group_name=resource_group_name, @@ -489,10 +489,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -500,6 +500,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -510,8 +514,12 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_sql_database( @@ -539,7 +547,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_sql_database_initial( # type: ignore + raw_result = await 
self._delete_sql_database_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -549,6 +557,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -611,7 +620,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -625,7 +633,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -639,7 +647,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods database_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -653,7 +661,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: 
ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -675,10 +683,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -686,20 +694,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -823,10 +833,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = 
self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -850,7 +861,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_sql_database_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -863,7 +874,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_sql_database_to_autoscale_request( resource_group_name=resource_group_name, @@ -874,10 +885,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -885,20 +896,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) 
raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -942,10 +955,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -969,7 +983,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_sql_database_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -982,7 +996,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", 
None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_sql_database_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -993,10 +1007,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1004,20 +1018,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1061,10 +1077,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = 
self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1131,7 +1148,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -1147,7 +1163,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -1219,7 +1234,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1233,7 +1247,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("SqlContainerGetResults", pipeline_response) + deserialized = self._deserialize("SqlContainerGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1248,7 +1262,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods container_name: str, create_update_sql_container_parameters: Union[_models.SqlContainerCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.SqlContainerGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ 
-1262,7 +1276,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.SqlContainerGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1285,10 +1299,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1296,20 +1310,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("SqlContainerGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ 
-1439,10 +1455,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("SqlContainerGetResults", pipeline_response) + deserialized = self._deserialize("SqlContainerGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1464,9 +1481,9 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_sql_container_initial( # pylint: disable=inconsistent-return-statements + async def _delete_sql_container_initial( self, resource_group_name: str, account_name: str, database_name: str, container_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1479,7 +1496,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_sql_container_request( resource_group_name=resource_group_name, @@ -1491,10 +1508,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, 
**kwargs ) @@ -1502,6 +1519,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -1512,8 +1533,12 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_sql_container( @@ -1543,7 +1568,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_sql_container_initial( # type: ignore + raw_result = await self._delete_sql_container_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -1554,6 +1579,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1619,7 +1645,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = 
self._client.format_url(_request.url) _stream = False @@ -1633,7 +1658,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1648,7 +1673,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods container_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1662,7 +1687,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1685,10 +1710,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1696,20 +1721,22 @@ class SqlResourcesOperations: # 
pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1843,10 +1870,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1870,7 +1898,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_sql_container_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, container_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: 
ResourceNotFoundError, @@ -1883,7 +1911,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_sql_container_to_autoscale_request( resource_group_name=resource_group_name, @@ -1895,10 +1923,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1906,20 +1934,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, 
deserialized, response_headers) # type: ignore @@ -1966,10 +1996,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1993,7 +2024,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods async def _migrate_sql_container_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, container_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2006,7 +2037,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_sql_container_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -2018,10 +2049,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: 
disable=protected-access _request, stream=_stream, **kwargs ) @@ -2029,20 +2060,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2089,10 +2122,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2159,7 +2193,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -2175,7 +2208,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( 
"GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -2252,7 +2284,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -2266,7 +2297,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ClientEncryptionKeyGetResults", pipeline_response) + deserialized = self._deserialize("ClientEncryptionKeyGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2283,7 +2314,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _models.ClientEncryptionKeyCreateUpdateParameters, IO[bytes] ], **kwargs: Any - ) -> Optional[_models.ClientEncryptionKeyGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2297,7 +2328,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ClientEncryptionKeyGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2322,10 +2353,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = 
_convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2333,20 +2364,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ClientEncryptionKeyGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2485,10 +2518,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ClientEncryptionKeyGetResults", pipeline_response) + deserialized = self._deserialize("ClientEncryptionKeyGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2558,7 +2592,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods 
headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -2574,7 +2607,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -2655,7 +2687,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -2669,7 +2700,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("SqlStoredProcedureGetResults", pipeline_response) + deserialized = self._deserialize("SqlStoredProcedureGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2687,7 +2718,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _models.SqlStoredProcedureCreateUpdateParameters, IO[bytes] ], **kwargs: Any - ) -> Optional[_models.SqlStoredProcedureGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2701,7 +2732,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.SqlStoredProcedureGetResults]] = kwargs.pop("cls", None) + cls: 
ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2727,10 +2758,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2738,20 +2769,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("SqlStoredProcedureGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2897,10 +2930,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("SqlStoredProcedureGetResults", pipeline_response) + deserialized = 
self._deserialize("SqlStoredProcedureGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2922,7 +2956,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_sql_stored_procedure_initial( # pylint: disable=inconsistent-return-statements + async def _delete_sql_stored_procedure_initial( self, resource_group_name: str, account_name: str, @@ -2930,7 +2964,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods container_name: str, stored_procedure_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2943,7 +2977,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_sql_stored_procedure_request( resource_group_name=resource_group_name, @@ -2956,10 +2990,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2967,6 +3001,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # 
Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -2977,8 +3015,12 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_sql_stored_procedure( @@ -3016,7 +3058,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_sql_stored_procedure_initial( # type: ignore + raw_result = await self._delete_sql_stored_procedure_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -3028,6 +3070,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -3097,7 +3140,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -3113,7 +3155,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, 
_parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -3194,7 +3235,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -3208,7 +3248,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("SqlUserDefinedFunctionGetResults", pipeline_response) + deserialized = self._deserialize("SqlUserDefinedFunctionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -3226,7 +3266,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _models.SqlUserDefinedFunctionCreateUpdateParameters, IO[bytes] ], **kwargs: Any - ) -> Optional[_models.SqlUserDefinedFunctionGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -3240,7 +3280,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.SqlUserDefinedFunctionGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -3266,10 +3306,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = 
_convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -3277,20 +3317,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("SqlUserDefinedFunctionGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -3436,10 +3478,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("SqlUserDefinedFunctionGetResults", pipeline_response) + deserialized = self._deserialize("SqlUserDefinedFunctionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -3461,7 +3504,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods 
self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_sql_user_defined_function_initial( # pylint: disable=inconsistent-return-statements,name-too-long + async def _delete_sql_user_defined_function_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, @@ -3469,7 +3512,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods container_name: str, user_defined_function_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -3482,7 +3525,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_sql_user_defined_function_request( resource_group_name=resource_group_name, @@ -3495,10 +3538,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -3506,6 +3549,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise 
HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -3516,8 +3563,12 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_sql_user_defined_function( @@ -3555,7 +3606,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_sql_user_defined_function_initial( # type: ignore + raw_result = await self._delete_sql_user_defined_function_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -3567,6 +3618,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -3636,7 +3688,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -3652,7 +3703,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return 
_request @@ -3733,7 +3783,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -3747,7 +3796,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("SqlTriggerGetResults", pipeline_response) + deserialized = self._deserialize("SqlTriggerGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -3763,7 +3812,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods trigger_name: str, create_update_sql_trigger_parameters: Union[_models.SqlTriggerCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.SqlTriggerGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -3777,7 +3826,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.SqlTriggerGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -3801,10 +3850,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = 
await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -3812,20 +3861,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("SqlTriggerGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -3965,10 +4016,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("SqlTriggerGetResults", pipeline_response) + deserialized = self._deserialize("SqlTriggerGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -3990,7 +4042,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_sql_trigger_initial( # pylint: disable=inconsistent-return-statements + async def _delete_sql_trigger_initial( self, 
resource_group_name: str, account_name: str, @@ -3998,7 +4050,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods container_name: str, trigger_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -4011,7 +4063,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_sql_trigger_request( resource_group_name=resource_group_name, @@ -4024,10 +4076,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -4035,6 +4087,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -4045,8 +4101,12 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, 
decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_sql_trigger( @@ -4084,7 +4144,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_sql_trigger_initial( # type: ignore + raw_result = await self._delete_sql_trigger_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -4096,6 +4156,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -4157,7 +4218,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -4171,7 +4231,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("SqlRoleDefinitionGetResults", pipeline_response) + deserialized = self._deserialize("SqlRoleDefinitionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -4185,7 +4245,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods account_name: str, create_update_sql_role_definition_parameters: 
Union[_models.SqlRoleDefinitionCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.SqlRoleDefinitionGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -4199,7 +4259,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.SqlRoleDefinitionGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -4223,10 +4283,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -4234,12 +4294,14 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("SqlRoleDefinitionGetResults", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # 
type: ignore @@ -4364,10 +4426,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("SqlRoleDefinitionGetResults", pipeline_response) + deserialized = self._deserialize("SqlRoleDefinitionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -4389,9 +4452,9 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_sql_role_definition_initial( # pylint: disable=inconsistent-return-statements + async def _delete_sql_role_definition_initial( self, role_definition_id: str, resource_group_name: str, account_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -4404,7 +4467,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_sql_role_definition_request( role_definition_id=role_definition_id, @@ -4415,10 +4478,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access 
_request, stream=_stream, **kwargs ) @@ -4426,11 +4489,19 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_sql_role_definition( @@ -4458,7 +4529,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_sql_role_definition_initial( # type: ignore + raw_result = await self._delete_sql_role_definition_initial( role_definition_id=role_definition_id, resource_group_name=resource_group_name, account_name=account_name, @@ -4468,6 +4539,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -4531,7 +4603,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -4547,7 +4618,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = 
HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -4616,7 +4686,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -4630,7 +4699,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("SqlRoleAssignmentGetResults", pipeline_response) + deserialized = self._deserialize("SqlRoleAssignmentGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -4644,7 +4713,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods account_name: str, create_update_sql_role_assignment_parameters: Union[_models.SqlRoleAssignmentCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.SqlRoleAssignmentGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -4658,7 +4727,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.SqlRoleAssignmentGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -4682,10 +4751,10 @@ class SqlResourcesOperations: # pylint: 
disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -4693,12 +4762,14 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("SqlRoleAssignmentGetResults", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -4823,10 +4894,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("SqlRoleAssignmentGetResults", pipeline_response) + deserialized = self._deserialize("SqlRoleAssignmentGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -4848,9 +4920,9 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_sql_role_assignment_initial( # pylint: disable=inconsistent-return-statements + async def 
_delete_sql_role_assignment_initial( self, role_assignment_id: str, resource_group_name: str, account_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -4863,7 +4935,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_sql_role_assignment_request( role_assignment_id=role_assignment_id, @@ -4874,10 +4946,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -4885,11 +4957,19 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def 
begin_delete_sql_role_assignment( @@ -4917,7 +4997,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_sql_role_assignment_initial( # type: ignore + raw_result = await self._delete_sql_role_assignment_initial( role_assignment_id=role_assignment_id, resource_group_name=resource_group_name, account_name=account_name, @@ -4927,6 +5007,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -4990,7 +5071,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -5006,7 +5086,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -5043,7 +5122,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods container_name: str, location: Union[_models.ContinuousBackupRestoreLocation, IO[bytes]], **kwargs: Any - ) -> Optional[_models.BackupInformation]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -5057,7 +5136,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", 
self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.BackupInformation]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -5080,10 +5159,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -5091,12 +5170,14 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -5223,10 +5304,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = self._deserialize("BackupInformation", pipeline_response.http_response) if cls: return cls(pipeline_response, 
deserialized, {}) # type: ignore return deserialized diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_table_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_table_resources_operations.py index 4e5a6af288e..f1f8f4a797c 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_table_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/aio/operations/_table_resources_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, AsyncIterable, AsyncIterator, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.async_paging import AsyncItemPaged, AsyncList @@ -18,12 +18,13 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import AsyncHttpResponse from azure.core.polling import AsyncLROPoller, AsyncNoPolling, AsyncPollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.utils import case_insensitive_dict @@ -31,7 +32,6 @@ from azure.mgmt.core.exceptions import ARMErrorFormat from azure.mgmt.core.polling.async_arm_polling import AsyncARMPolling from ... 
import models as _models -from ..._vendor import _convert_request from ...operations._table_resources_operations import ( build_create_update_table_request, build_delete_table_request, @@ -111,7 +111,6 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -127,7 +126,6 @@ class TableResourcesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -196,7 +194,6 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -210,7 +207,7 @@ class TableResourcesOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("TableGetResults", pipeline_response) + deserialized = self._deserialize("TableGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -224,7 +221,7 @@ class TableResourcesOperations: table_name: str, create_update_table_parameters: Union[_models.TableCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.TableGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -238,7 +235,7 @@ class TableResourcesOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.TableGetResults]] = kwargs.pop("cls", None) + cls: 
ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -260,10 +257,10 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -271,20 +268,22 @@ class TableResourcesOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("TableGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -403,10 +402,11 @@ class TableResourcesOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("TableGetResults", pipeline_response) + deserialized = self._deserialize("TableGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ 
-428,9 +428,9 @@ class TableResourcesOperations: self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - async def _delete_table_initial( # pylint: disable=inconsistent-return-statements + async def _delete_table_initial( self, resource_group_name: str, account_name: str, table_name: str, **kwargs: Any - ) -> None: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -443,7 +443,7 @@ class TableResourcesOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_table_request( resource_group_name=resource_group_name, @@ -454,10 +454,10 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -465,6 +465,10 @@ class TableResourcesOperations: response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -475,8 +479,12 @@ class TableResourcesOperations: ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return 
cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace_async async def begin_delete_table( @@ -504,7 +512,7 @@ class TableResourcesOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = await self._delete_table_initial( # type: ignore + raw_result = await self._delete_table_initial( resource_group_name=resource_group_name, account_name=account_name, table_name=table_name, @@ -514,6 +522,7 @@ class TableResourcesOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -576,7 +585,6 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -590,7 +598,7 @@ class TableResourcesOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -604,7 +612,7 @@ class TableResourcesOperations: table_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -618,7 +626,7 @@ class 
TableResourcesOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -640,10 +648,10 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -651,20 +659,22 @@ class TableResourcesOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -788,10 +798,11 @@ class TableResourcesOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore 
kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -815,7 +826,7 @@ class TableResourcesOperations: async def _migrate_table_to_autoscale_initial( self, resource_group_name: str, account_name: str, table_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -828,7 +839,7 @@ class TableResourcesOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_table_to_autoscale_request( resource_group_name=resource_group_name, @@ -839,10 +850,10 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -850,20 +861,22 @@ class TableResourcesOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise 
HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -907,10 +920,11 @@ class TableResourcesOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -934,7 +948,7 @@ class TableResourcesOperations: async def _migrate_table_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, table_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -947,7 +961,7 @@ class TableResourcesOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_table_to_manual_throughput_request( 
resource_group_name=resource_group_name, @@ -958,10 +972,10 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -969,20 +983,22 @@ class TableResourcesOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1026,10 +1042,11 @@ class TableResourcesOperations: params=_params, **kwargs ) + await raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1058,7 +1075,7 @@ class 
TableResourcesOperations: table_name: str, location: Union[_models.ContinuousBackupRestoreLocation, IO[bytes]], **kwargs: Any - ) -> Optional[_models.BackupInformation]: + ) -> AsyncIterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1072,7 +1089,7 @@ class TableResourcesOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.BackupInformation]] = kwargs.pop("cls", None) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1094,10 +1111,10 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1105,12 +1122,14 @@ class TableResourcesOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1227,10 +1246,11 @@ class TableResourcesOperations: params=_params, **kwargs ) + await 
raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = self._deserialize("BackupInformation", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/models/__init__.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/models/__init__.py index 7471e35315a..5c4187ac41f 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/models/__init__.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/models/__init__.py @@ -81,7 +81,10 @@ from ._models_py3 import DatabaseAccountRegenerateKeyParameters from ._models_py3 import DatabaseAccountUpdateParameters from ._models_py3 import DatabaseAccountsListResult from ._models_py3 import DatabaseRestoreResource +from ._models_py3 import ErrorAdditionalInfo +from ._models_py3 import ErrorDetail from ._models_py3 import ErrorResponse +from ._models_py3 import ErrorResponseAutoGenerated from ._models_py3 import ExcludedPath from ._models_py3 import ExtendedResourceProperties from ._models_py3 import FailoverPolicies @@ -408,7 +411,10 @@ __all__ = [ "DatabaseAccountUpdateParameters", "DatabaseAccountsListResult", "DatabaseRestoreResource", + "ErrorAdditionalInfo", + "ErrorDetail", "ErrorResponse", + "ErrorResponseAutoGenerated", "ExcludedPath", "ExtendedResourceProperties", "FailoverPolicies", diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/models/_cosmos_db_management_client_enums.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/models/_cosmos_db_management_client_enums.py index eeaa26ece86..9957c5fd5d7 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/models/_cosmos_db_management_client_enums.py +++ 
b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/models/_cosmos_db_management_client_enums.py @@ -324,7 +324,7 @@ class RoleDefinitionType(str, Enum, metaclass=CaseInsensitiveEnumMeta): class ServerVersion(str, Enum, metaclass=CaseInsensitiveEnumMeta): - """Describes the ServerVersion of an a MongoDB account.""" + """Describes the version of the MongoDB account.""" THREE2 = "3.2" THREE6 = "3.6" @@ -332,6 +332,7 @@ class ServerVersion(str, Enum, metaclass=CaseInsensitiveEnumMeta): FOUR2 = "4.2" FIVE0 = "5.0" SIX0 = "6.0" + SEVEN0 = "7.0" class ServiceSize(str, Enum, metaclass=CaseInsensitiveEnumMeta): diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/models/_models_py3.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/models/_models_py3.py index cdadb0d91c8..efbc35393d0 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/models/_models_py3.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/models/_models_py3.py @@ -69,8 +69,8 @@ class AnalyticalStorageConfiguration(_serialization.Model): class ApiProperties(_serialization.Model): """ApiProperties. - :ivar server_version: Describes the ServerVersion of an a MongoDB account. Known values are: - "3.2", "3.6", "4.0", "4.2", "5.0", and "6.0". + :ivar server_version: Describes the version of the MongoDB account. Known values are: "3.2", + "3.6", "4.0", "4.2", "5.0", "6.0", and "7.0". :vartype server_version: str or ~azure.mgmt.cosmosdb.models.ServerVersion """ @@ -80,8 +80,8 @@ class ApiProperties(_serialization.Model): def __init__(self, *, server_version: Optional[Union[str, "_models.ServerVersion"]] = None, **kwargs: Any) -> None: """ - :keyword server_version: Describes the ServerVersion of an a MongoDB account. Known values are: - "3.2", "3.6", "4.0", "4.2", "5.0", and "6.0". + :keyword server_version: Describes the version of the MongoDB account. Known values are: "3.2", + "3.6", "4.0", "4.2", "5.0", "6.0", and "7.0". 
:paramtype server_version: str or ~azure.mgmt.cosmosdb.models.ServerVersion """ super().__init__(**kwargs) @@ -2955,8 +2955,8 @@ class DatabaseAccountCreateUpdateParameters(ARMResourceProperties): # pylint: d 1.2. Cassandra and Mongo APIs only work with Tls 1.2. Known values are: "Tls", "Tls11", and "Tls12". :vartype minimal_tls_version: str or ~azure.mgmt.cosmosdb.models.MinimalTlsVersion - :ivar enable_burst_capacity: Flag to indicate enabling/disabling of Burst Capacity Preview - feature on the account. + :ivar enable_burst_capacity: Flag to indicate enabling/disabling of Burst Capacity feature on + the account. :vartype enable_burst_capacity: bool :ivar customer_managed_key_status: Indicates the status of the Customer Managed Key feature on the account. In case there are errors, the property provides troubleshooting guidance. @@ -3153,8 +3153,8 @@ class DatabaseAccountCreateUpdateParameters(ARMResourceProperties): # pylint: d Tls 1.2. Cassandra and Mongo APIs only work with Tls 1.2. Known values are: "Tls", "Tls11", and "Tls12". :paramtype minimal_tls_version: str or ~azure.mgmt.cosmosdb.models.MinimalTlsVersion - :keyword enable_burst_capacity: Flag to indicate enabling/disabling of Burst Capacity Preview - feature on the account. + :keyword enable_burst_capacity: Flag to indicate enabling/disabling of Burst Capacity feature + on the account. :paramtype enable_burst_capacity: bool :keyword customer_managed_key_status: Indicates the status of the Customer Managed Key feature on the account. In case there are errors, the property provides troubleshooting guidance. @@ -3328,8 +3328,8 @@ class DatabaseAccountGetResults(ARMResourceProperties): # pylint: disable=too-m 1.2. Cassandra and Mongo APIs only work with Tls 1.2. Known values are: "Tls", "Tls11", and "Tls12". :vartype minimal_tls_version: str or ~azure.mgmt.cosmosdb.models.MinimalTlsVersion - :ivar enable_burst_capacity: Flag to indicate enabling/disabling of Burst Capacity Preview - feature on the account. 
+ :ivar enable_burst_capacity: Flag to indicate enabling/disabling of Burst Capacity feature on + the account. :vartype enable_burst_capacity: bool :ivar customer_managed_key_status: Indicates the status of the Customer Managed Key feature on the account. In case there are errors, the property provides troubleshooting guidance. @@ -3539,8 +3539,8 @@ class DatabaseAccountGetResults(ARMResourceProperties): # pylint: disable=too-m Tls 1.2. Cassandra and Mongo APIs only work with Tls 1.2. Known values are: "Tls", "Tls11", and "Tls12". :paramtype minimal_tls_version: str or ~azure.mgmt.cosmosdb.models.MinimalTlsVersion - :keyword enable_burst_capacity: Flag to indicate enabling/disabling of Burst Capacity Preview - feature on the account. + :keyword enable_burst_capacity: Flag to indicate enabling/disabling of Burst Capacity feature + on the account. :paramtype enable_burst_capacity: bool :keyword customer_managed_key_status: Indicates the status of the Customer Managed Key feature on the account. In case there are errors, the property provides troubleshooting guidance. @@ -3863,8 +3863,8 @@ class DatabaseAccountUpdateParameters(_serialization.Model): # pylint: disable= 1.2. Cassandra and Mongo APIs only work with Tls 1.2. Known values are: "Tls", "Tls11", and "Tls12". :vartype minimal_tls_version: str or ~azure.mgmt.cosmosdb.models.MinimalTlsVersion - :ivar enable_burst_capacity: Flag to indicate enabling/disabling of Burst Capacity Preview - feature on the account. + :ivar enable_burst_capacity: Flag to indicate enabling/disabling of Burst Capacity feature on + the account. :vartype enable_burst_capacity: bool :ivar customer_managed_key_status: Indicates the status of the Customer Managed Key feature on the account. In case there are errors, the property provides troubleshooting guidance. @@ -4036,8 +4036,8 @@ class DatabaseAccountUpdateParameters(_serialization.Model): # pylint: disable= Tls 1.2. Cassandra and Mongo APIs only work with Tls 1.2. 
Known values are: "Tls", "Tls11", and "Tls12". :paramtype minimal_tls_version: str or ~azure.mgmt.cosmosdb.models.MinimalTlsVersion - :keyword enable_burst_capacity: Flag to indicate enabling/disabling of Burst Capacity Preview - feature on the account. + :keyword enable_burst_capacity: Flag to indicate enabling/disabling of Burst Capacity feature + on the account. :paramtype enable_burst_capacity: bool :keyword customer_managed_key_status: Indicates the status of the Customer Managed Key feature on the account. In case there are errors, the property provides troubleshooting guidance. @@ -4649,6 +4649,77 @@ class DataTransferServiceResourceProperties(ServiceResourceProperties): self.locations = None +class ErrorAdditionalInfo(_serialization.Model): + """The resource management error additional info. + + Variables are only populated by the server, and will be ignored when sending a request. + + :ivar type: The additional info type. + :vartype type: str + :ivar info: The additional info. + :vartype info: JSON + """ + + _validation = { + "type": {"readonly": True}, + "info": {"readonly": True}, + } + + _attribute_map = { + "type": {"key": "type", "type": "str"}, + "info": {"key": "info", "type": "object"}, + } + + def __init__(self, **kwargs: Any) -> None: + """ """ + super().__init__(**kwargs) + self.type = None + self.info = None + + +class ErrorDetail(_serialization.Model): + """The error detail. + + Variables are only populated by the server, and will be ignored when sending a request. + + :ivar code: The error code. + :vartype code: str + :ivar message: The error message. + :vartype message: str + :ivar target: The error target. + :vartype target: str + :ivar details: The error details. + :vartype details: list[~azure.mgmt.cosmosdb.models.ErrorDetail] + :ivar additional_info: The error additional info. 
+ :vartype additional_info: list[~azure.mgmt.cosmosdb.models.ErrorAdditionalInfo] + """ + + _validation = { + "code": {"readonly": True}, + "message": {"readonly": True}, + "target": {"readonly": True}, + "details": {"readonly": True}, + "additional_info": {"readonly": True}, + } + + _attribute_map = { + "code": {"key": "code", "type": "str"}, + "message": {"key": "message", "type": "str"}, + "target": {"key": "target", "type": "str"}, + "details": {"key": "details", "type": "[ErrorDetail]"}, + "additional_info": {"key": "additionalInfo", "type": "[ErrorAdditionalInfo]"}, + } + + def __init__(self, **kwargs: Any) -> None: + """ """ + super().__init__(**kwargs) + self.code = None + self.message = None + self.target = None + self.details = None + self.additional_info = None + + class ErrorResponse(_serialization.Model): """Error Response. @@ -4675,6 +4746,27 @@ class ErrorResponse(_serialization.Model): self.message = message +class ErrorResponseAutoGenerated(_serialization.Model): + """Common error response for all Azure Resource Manager APIs to return error details for failed + operations. (This also follows the OData error response format.). + + :ivar error: The error object. + :vartype error: ~azure.mgmt.cosmosdb.models.ErrorDetail + """ + + _attribute_map = { + "error": {"key": "error", "type": "ErrorDetail"}, + } + + def __init__(self, *, error: Optional["_models.ErrorDetail"] = None, **kwargs: Any) -> None: + """ + :keyword error: The error object. + :paramtype error: ~azure.mgmt.cosmosdb.models.ErrorDetail + """ + super().__init__(**kwargs) + self.error = error + + class ExcludedPath(_serialization.Model): """ExcludedPath. @@ -8775,11 +8867,15 @@ class RestoreParametersBase(_serialization.Model): :vartype restore_source: str :ivar restore_timestamp_in_utc: Time to which the account has to be restored (ISO-8601 format). 
:vartype restore_timestamp_in_utc: ~datetime.datetime + :ivar restore_with_ttl_disabled: Specifies whether the restored account will have Time-To-Live + disabled upon the successful restore. + :vartype restore_with_ttl_disabled: bool """ _attribute_map = { "restore_source": {"key": "restoreSource", "type": "str"}, "restore_timestamp_in_utc": {"key": "restoreTimestampInUtc", "type": "iso-8601"}, + "restore_with_ttl_disabled": {"key": "restoreWithTtlDisabled", "type": "bool"}, } def __init__( @@ -8787,6 +8883,7 @@ class RestoreParametersBase(_serialization.Model): *, restore_source: Optional[str] = None, restore_timestamp_in_utc: Optional[datetime.datetime] = None, + restore_with_ttl_disabled: Optional[bool] = None, **kwargs: Any ) -> None: """ @@ -8797,10 +8894,14 @@ class RestoreParametersBase(_serialization.Model): :keyword restore_timestamp_in_utc: Time to which the account has to be restored (ISO-8601 format). :paramtype restore_timestamp_in_utc: ~datetime.datetime + :keyword restore_with_ttl_disabled: Specifies whether the restored account will have + Time-To-Live disabled upon the successful restore. + :paramtype restore_with_ttl_disabled: bool """ super().__init__(**kwargs) self.restore_source = restore_source self.restore_timestamp_in_utc = restore_timestamp_in_utc + self.restore_with_ttl_disabled = restore_with_ttl_disabled class ResourceRestoreParameters(RestoreParametersBase): @@ -8812,6 +8913,9 @@ class ResourceRestoreParameters(RestoreParametersBase): :vartype restore_source: str :ivar restore_timestamp_in_utc: Time to which the account has to be restored (ISO-8601 format). :vartype restore_timestamp_in_utc: ~datetime.datetime + :ivar restore_with_ttl_disabled: Specifies whether the restored account will have Time-To-Live + disabled upon the successful restore. 
+ :vartype restore_with_ttl_disabled: bool """ @@ -10498,6 +10602,9 @@ class RestoreParameters(RestoreParametersBase): :vartype restore_source: str :ivar restore_timestamp_in_utc: Time to which the account has to be restored (ISO-8601 format). :vartype restore_timestamp_in_utc: ~datetime.datetime + :ivar restore_with_ttl_disabled: Specifies whether the restored account will have Time-To-Live + disabled upon the successful restore. + :vartype restore_with_ttl_disabled: bool :ivar restore_mode: Describes the mode of the restore. "PointInTime" :vartype restore_mode: str or ~azure.mgmt.cosmosdb.models.RestoreMode :ivar databases_to_restore: List of specific databases available for restore. @@ -10512,6 +10619,7 @@ class RestoreParameters(RestoreParametersBase): _attribute_map = { "restore_source": {"key": "restoreSource", "type": "str"}, "restore_timestamp_in_utc": {"key": "restoreTimestampInUtc", "type": "iso-8601"}, + "restore_with_ttl_disabled": {"key": "restoreWithTtlDisabled", "type": "bool"}, "restore_mode": {"key": "restoreMode", "type": "str"}, "databases_to_restore": {"key": "databasesToRestore", "type": "[DatabaseRestoreResource]"}, "gremlin_databases_to_restore": { @@ -10526,6 +10634,7 @@ class RestoreParameters(RestoreParametersBase): *, restore_source: Optional[str] = None, restore_timestamp_in_utc: Optional[datetime.datetime] = None, + restore_with_ttl_disabled: Optional[bool] = None, restore_mode: Optional[Union[str, "_models.RestoreMode"]] = None, databases_to_restore: Optional[List["_models.DatabaseRestoreResource"]] = None, gremlin_databases_to_restore: Optional[List["_models.GremlinDatabaseRestoreResource"]] = None, @@ -10540,6 +10649,9 @@ class RestoreParameters(RestoreParametersBase): :keyword restore_timestamp_in_utc: Time to which the account has to be restored (ISO-8601 format). 
:paramtype restore_timestamp_in_utc: ~datetime.datetime + :keyword restore_with_ttl_disabled: Specifies whether the restored account will have + Time-To-Live disabled upon the successful restore. + :paramtype restore_with_ttl_disabled: bool :keyword restore_mode: Describes the mode of the restore. "PointInTime" :paramtype restore_mode: str or ~azure.mgmt.cosmosdb.models.RestoreMode :keyword databases_to_restore: List of specific databases available for restore. @@ -10551,7 +10663,12 @@ class RestoreParameters(RestoreParametersBase): :keyword tables_to_restore: List of specific tables available for restore. :paramtype tables_to_restore: list[str] """ - super().__init__(restore_source=restore_source, restore_timestamp_in_utc=restore_timestamp_in_utc, **kwargs) + super().__init__( + restore_source=restore_source, + restore_timestamp_in_utc=restore_timestamp_in_utc, + restore_with_ttl_disabled=restore_with_ttl_disabled, + **kwargs + ) self.restore_mode = restore_mode self.databases_to_restore = databases_to_restore self.gremlin_databases_to_restore = gremlin_databases_to_restore diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_cassandra_clusters_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_cassandra_clusters_operations.py index 5061742402c..efc8557f2c0 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_cassandra_clusters_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_cassandra_clusters_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, Callable, Dict, IO, Iterable, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, Callable, Dict, IO, Iterable, Iterator, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.exceptions import ( @@ -17,13 +17,14 @@ from azure.core.exceptions import ( 
ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse from azure.core.polling import LROPoller, NoPolling, PollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat @@ -31,7 +32,6 @@ from azure.mgmt.core.polling.arm_polling import ARMPolling from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -48,7 +48,7 @@ def build_list_by_subscription_request(subscription_id: str, **kwargs: Any) -> H _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -74,7 +74,7 @@ def build_list_by_resource_group_request(resource_group_name: str, subscription_ _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -104,7 +104,7 @@ def build_get_request(resource_group_name: str, cluster_name: str, subscription_ _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = 
case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -139,7 +139,7 @@ def build_delete_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -174,7 +174,7 @@ def build_create_update_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -212,7 +212,7 @@ def build_update_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -250,7 +250,7 @@ def build_invoke_command_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) 
+ api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -288,7 +288,7 @@ def build_deallocate_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -323,7 +323,7 @@ def build_start_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -358,7 +358,7 @@ def build_status_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -437,7 +437,6 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -453,7 +452,6 @@ class CassandraClustersOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request 
@@ -517,7 +515,6 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -533,7 +530,6 @@ class CassandraClustersOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -597,7 +593,6 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -611,16 +606,14 @@ class CassandraClustersOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ClusterResource", pipeline_response) + deserialized = self._deserialize("ClusterResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore - def _delete_initial( # pylint: disable=inconsistent-return-statements - self, resource_group_name: str, cluster_name: str, **kwargs: Any - ) -> None: + def _delete_initial(self, resource_group_name: str, cluster_name: str, **kwargs: Any) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -633,7 +626,7 @@ class CassandraClustersOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_request( resource_group_name=resource_group_name, @@ -643,10 +636,10 @@ class 
CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -654,11 +647,19 @@ class CassandraClustersOperations: response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete(self, resource_group_name: str, cluster_name: str, **kwargs: Any) -> LROPoller[None]: @@ -682,7 +683,7 @@ class CassandraClustersOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_initial( # type: ignore + raw_result = self._delete_initial( resource_group_name=resource_group_name, cluster_name=cluster_name, api_version=api_version, @@ -691,6 +692,7 @@ class CassandraClustersOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -718,7 +720,7 @@ class CassandraClustersOperations: cluster_name: str, body: Union[_models.ClusterResource, IO[bytes]], **kwargs: Any - ) -> 
_models.ClusterResource: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -732,7 +734,7 @@ class CassandraClustersOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.ClusterResource] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -753,10 +755,10 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -764,14 +766,14 @@ class CassandraClustersOperations: response = pipeline_response.http_response if response.status_code not in [200, 201]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - if response.status_code == 200: - deserialized = self._deserialize("ClusterResource", pipeline_response) - - if response.status_code == 201: - deserialized = self._deserialize("ClusterResource", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -883,10 +885,11 @@ class CassandraClustersOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def 
get_long_running_output(pipeline_response): - deserialized = self._deserialize("ClusterResource", pipeline_response) + deserialized = self._deserialize("ClusterResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -914,7 +917,7 @@ class CassandraClustersOperations: cluster_name: str, body: Union[_models.ClusterResource, IO[bytes]], **kwargs: Any - ) -> _models.ClusterResource: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -928,7 +931,7 @@ class CassandraClustersOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.ClusterResource] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -949,10 +952,10 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -960,14 +963,14 @@ class CassandraClustersOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - if response.status_code == 200: - deserialized = self._deserialize("ClusterResource", pipeline_response) - - if 
response.status_code == 202: - deserialized = self._deserialize("ClusterResource", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1074,10 +1077,11 @@ class CassandraClustersOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ClusterResource", pipeline_response) + deserialized = self._deserialize("ClusterResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1105,7 +1109,7 @@ class CassandraClustersOperations: cluster_name: str, body: Union[_models.CommandPostBody, IO[bytes]], **kwargs: Any - ) -> _models.CommandOutput: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1119,7 +1123,7 @@ class CassandraClustersOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.CommandOutput] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1140,10 +1144,10 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1151,10 +1155,14 @@ class CassandraClustersOperations: response = pipeline_response.http_response if 
response.status_code not in [202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("CommandOutput", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1261,10 +1269,11 @@ class CassandraClustersOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("CommandOutput", pipeline_response) + deserialized = self._deserialize("CommandOutput", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1286,9 +1295,7 @@ class CassandraClustersOperations: self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _deallocate_initial( # pylint: disable=inconsistent-return-statements - self, resource_group_name: str, cluster_name: str, **kwargs: Any - ) -> None: + def _deallocate_initial(self, resource_group_name: str, cluster_name: str, **kwargs: Any) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1301,7 +1308,7 @@ class CassandraClustersOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_deallocate_request( resource_group_name=resource_group_name, @@ -1311,10 +1318,10 @@ class 
CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1322,11 +1329,19 @@ class CassandraClustersOperations: response = pipeline_response.http_response if response.status_code not in [202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_deallocate(self, resource_group_name: str, cluster_name: str, **kwargs: Any) -> LROPoller[None]: @@ -1352,7 +1367,7 @@ class CassandraClustersOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._deallocate_initial( # type: ignore + raw_result = self._deallocate_initial( resource_group_name=resource_group_name, cluster_name=cluster_name, api_version=api_version, @@ -1361,6 +1376,7 @@ class CassandraClustersOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1382,9 +1398,7 @@ class CassandraClustersOperations: ) return LROPoller[None](self._client, raw_result, get_long_running_output, 
polling_method) # type: ignore - def _start_initial( # pylint: disable=inconsistent-return-statements - self, resource_group_name: str, cluster_name: str, **kwargs: Any - ) -> None: + def _start_initial(self, resource_group_name: str, cluster_name: str, **kwargs: Any) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1397,7 +1411,7 @@ class CassandraClustersOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_start_request( resource_group_name=resource_group_name, @@ -1407,10 +1421,10 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1418,11 +1432,19 @@ class CassandraClustersOperations: response = pipeline_response.http_response if response.status_code not in [202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_start(self, resource_group_name: str, cluster_name: str, **kwargs: 
Any) -> LROPoller[None]: @@ -1448,7 +1470,7 @@ class CassandraClustersOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._start_initial( # type: ignore + raw_result = self._start_initial( resource_group_name=resource_group_name, cluster_name=cluster_name, api_version=api_version, @@ -1457,6 +1479,7 @@ class CassandraClustersOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1515,7 +1538,6 @@ class CassandraClustersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1529,7 +1551,7 @@ class CassandraClustersOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("CassandraClusterPublicStatus", pipeline_response) + deserialized = self._deserialize("CassandraClusterPublicStatus", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_cassandra_data_centers_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_cassandra_data_centers_operations.py index 1a6529a9cd2..a23316628a1 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_cassandra_data_centers_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_cassandra_data_centers_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, Callable, Dict, IO, 
Iterable, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, Callable, Dict, IO, Iterable, Iterator, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.exceptions import ( @@ -17,13 +17,14 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse from azure.core.polling import LROPoller, NoPolling, PollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat @@ -31,7 +32,6 @@ from azure.mgmt.core.polling.arm_polling import ARMPolling from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -48,7 +48,7 @@ def build_list_request(resource_group_name: str, cluster_name: str, subscription _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -83,7 +83,7 @@ def build_get_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", 
"application/json") # Construct URL @@ -126,7 +126,7 @@ def build_delete_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -169,7 +169,7 @@ def build_create_update_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -215,7 +215,7 @@ def build_update_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -314,7 +314,6 @@ class CassandraDataCentersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -330,7 +329,6 @@ class CassandraDataCentersOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -399,7 +397,6 @@ class 
CassandraDataCentersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -413,16 +410,16 @@ class CassandraDataCentersOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DataCenterResource", pipeline_response) + deserialized = self._deserialize("DataCenterResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore - def _delete_initial( # pylint: disable=inconsistent-return-statements + def _delete_initial( self, resource_group_name: str, cluster_name: str, data_center_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -435,7 +432,7 @@ class CassandraDataCentersOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_request( resource_group_name=resource_group_name, @@ -446,10 +443,10 @@ class CassandraDataCentersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -457,11 +454,19 @@ class CassandraDataCentersOperations: response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the 
body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete( @@ -489,7 +494,7 @@ class CassandraDataCentersOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_initial( # type: ignore + raw_result = self._delete_initial( resource_group_name=resource_group_name, cluster_name=cluster_name, data_center_name=data_center_name, @@ -499,6 +504,7 @@ class CassandraDataCentersOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -527,7 +533,7 @@ class CassandraDataCentersOperations: data_center_name: str, body: Union[_models.DataCenterResource, IO[bytes]], **kwargs: Any - ) -> _models.DataCenterResource: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -541,7 +547,7 @@ class CassandraDataCentersOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.DataCenterResource] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -563,10 
+569,10 @@ class CassandraDataCentersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -574,14 +580,14 @@ class CassandraDataCentersOperations: response = pipeline_response.http_response if response.status_code not in [200, 201]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - if response.status_code == 200: - deserialized = self._deserialize("DataCenterResource", pipeline_response) - - if response.status_code == 201: - deserialized = self._deserialize("DataCenterResource", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -701,10 +707,11 @@ class CassandraDataCentersOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("DataCenterResource", pipeline_response) + deserialized = self._deserialize("DataCenterResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -733,7 +740,7 @@ class CassandraDataCentersOperations: data_center_name: str, body: Union[_models.DataCenterResource, IO[bytes]], **kwargs: Any - ) -> _models.DataCenterResource: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: 
ResourceNotFoundError, @@ -747,7 +754,7 @@ class CassandraDataCentersOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.DataCenterResource] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -769,10 +776,10 @@ class CassandraDataCentersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -780,14 +787,14 @@ class CassandraDataCentersOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - if response.status_code == 200: - deserialized = self._deserialize("DataCenterResource", pipeline_response) - - if response.status_code == 202: - deserialized = self._deserialize("DataCenterResource", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -904,10 +911,11 @@ class CassandraDataCentersOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("DataCenterResource", pipeline_response) + deserialized = 
self._deserialize("DataCenterResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_cassandra_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_cassandra_resources_operations.py index 7a55872d205..9785bfec05c 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_cassandra_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_cassandra_resources_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, Callable, Dict, IO, Iterable, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, Callable, Dict, IO, Iterable, Iterator, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.exceptions import ( @@ -17,13 +17,14 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse from azure.core.polling import LROPoller, NoPolling, PollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat @@ -31,7 +32,6 @@ from azure.mgmt.core.polling.arm_polling import ARMPolling from .. 
import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -50,7 +50,7 @@ def build_list_cassandra_keyspaces_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -85,7 +85,7 @@ def build_get_cassandra_keyspace_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -121,7 +121,7 @@ def build_create_update_cassandra_keyspace_request( # pylint: disable=name-too- _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -159,7 +159,7 @@ def build_delete_cassandra_keyspace_request( ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) # Construct URL _url = kwargs.pop( "template_url", @@ -190,7 +190,7 
@@ def build_get_cassandra_keyspace_throughput_request( # pylint: disable=name-too _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -226,7 +226,7 @@ def build_update_cassandra_keyspace_throughput_request( # pylint: disable=name- _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -265,7 +265,7 @@ def build_migrate_cassandra_keyspace_to_autoscale_request( # pylint: disable=na _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -301,7 +301,7 @@ def build_migrate_cassandra_keyspace_to_manual_throughput_request( # pylint: di _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -337,7 +337,7 @@ def 
build_list_cassandra_tables_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -378,7 +378,7 @@ def build_get_cassandra_table_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -420,7 +420,7 @@ def build_create_update_cassandra_table_request( # pylint: disable=name-too-lon _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -464,7 +464,7 @@ def build_delete_cassandra_table_request( ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) # Construct URL _url = kwargs.pop( "template_url", @@ -501,7 +501,7 @@ def build_get_cassandra_table_throughput_request( # pylint: disable=name-too-lo _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or 
{}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -543,7 +543,7 @@ def build_update_cassandra_table_throughput_request( # pylint: disable=name-too _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -588,7 +588,7 @@ def build_migrate_cassandra_table_to_autoscale_request( # pylint: disable=name- _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -630,7 +630,7 @@ def build_migrate_cassandra_table_to_manual_throughput_request( # pylint: disab _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -721,7 +721,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -737,7 
+736,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -807,7 +805,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -821,7 +818,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("CassandraKeyspaceGetResults", pipeline_response) + deserialized = self._deserialize("CassandraKeyspaceGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -835,7 +832,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods keyspace_name: str, create_update_cassandra_keyspace_parameters: Union[_models.CassandraKeyspaceCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.CassandraKeyspaceGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -849,7 +846,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.CassandraKeyspaceGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or 
"application/json" _json = None @@ -873,10 +870,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -884,20 +881,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("CassandraKeyspaceGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1018,10 +1017,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("CassandraKeyspaceGetResults", pipeline_response) + deserialized = self._deserialize("CassandraKeyspaceGetResults", pipeline_response.http_response) if cls: return 
cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1043,9 +1043,9 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_cassandra_keyspace_initial( # pylint: disable=inconsistent-return-statements + def _delete_cassandra_keyspace_initial( self, resource_group_name: str, account_name: str, keyspace_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1058,7 +1058,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_cassandra_keyspace_request( resource_group_name=resource_group_name, @@ -1069,10 +1069,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1080,6 +1080,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, 
error_format=ARMErrorFormat) @@ -1090,8 +1094,12 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_cassandra_keyspace( @@ -1119,7 +1127,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_cassandra_keyspace_initial( # type: ignore + raw_result = self._delete_cassandra_keyspace_initial( resource_group_name=resource_group_name, account_name=account_name, keyspace_name=keyspace_name, @@ -1129,6 +1137,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1191,7 +1200,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1205,7 +1213,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = 
self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1219,7 +1227,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods keyspace_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1233,7 +1241,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1255,10 +1263,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1266,20 +1274,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, 
error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1400,10 +1410,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1427,7 +1438,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_cassandra_keyspace_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, keyspace_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1440,7 +1451,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = 
kwargs.pop("cls", None) _request = build_migrate_cassandra_keyspace_to_autoscale_request( resource_group_name=resource_group_name, @@ -1451,10 +1462,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1462,20 +1473,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1518,10 +1531,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = 
self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1545,7 +1559,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_cassandra_keyspace_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, keyspace_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1558,7 +1572,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_cassandra_keyspace_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -1569,10 +1583,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1580,20 +1594,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, 
response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1636,10 +1652,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1705,7 +1722,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -1721,7 +1737,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -1793,7 +1808,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) 
_request.url = self._client.format_url(_request.url) _stream = False @@ -1807,7 +1821,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("CassandraTableGetResults", pipeline_response) + deserialized = self._deserialize("CassandraTableGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1822,7 +1836,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods table_name: str, create_update_cassandra_table_parameters: Union[_models.CassandraTableCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.CassandraTableGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1836,7 +1850,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.CassandraTableGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1861,10 +1875,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1872,20 +1886,22 @@ class 
CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("CassandraTableGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2016,10 +2032,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("CassandraTableGetResults", pipeline_response) + deserialized = self._deserialize("CassandraTableGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2041,9 +2058,9 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_cassandra_table_initial( # pylint: disable=inconsistent-return-statements + def _delete_cassandra_table_initial( self, resource_group_name: str, account_name: str, keyspace_name: str, table_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: 
error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2056,7 +2073,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_cassandra_table_request( resource_group_name=resource_group_name, @@ -2068,10 +2085,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2079,6 +2096,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -2089,8 +2110,12 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore 
@distributed_trace def begin_delete_cassandra_table( @@ -2120,7 +2145,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_cassandra_table_initial( # type: ignore + raw_result = self._delete_cassandra_table_initial( resource_group_name=resource_group_name, account_name=account_name, keyspace_name=keyspace_name, @@ -2131,6 +2156,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -2196,7 +2222,6 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -2210,7 +2235,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2225,7 +2250,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods table_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: 
ResourceNotFoundError, @@ -2239,7 +2264,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2262,10 +2287,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2273,20 +2298,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, 
response_headers) # type: ignore @@ -2417,10 +2444,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2444,7 +2472,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_cassandra_table_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, keyspace_name: str, table_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2457,7 +2485,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_cassandra_table_to_autoscale_request( resource_group_name=resource_group_name, @@ -2469,10 +2497,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, 
stream=_stream, **kwargs ) @@ -2480,20 +2508,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2539,10 +2569,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2566,7 +2597,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_cassandra_table_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, keyspace_name: str, table_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: 
MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2579,7 +2610,7 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_cassandra_table_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -2591,10 +2622,10 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2602,20 +2633,22 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = 
response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2661,10 +2694,11 @@ class CassandraResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_operations.py index 2305916c12f..8dcb3927c21 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. 
import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -54,7 +52,7 @@ def build_list_metrics_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -99,7 +97,7 @@ def build_list_usages_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -143,7 +141,7 @@ def build_list_metric_definitions_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -251,7 +249,6 @@ class CollectionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -267,7 +264,6 @@ class CollectionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -353,7 +349,6 @@ class CollectionOperations: 
headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -369,7 +364,6 @@ class CollectionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -444,7 +438,6 @@ class CollectionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -460,7 +453,6 @@ class CollectionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_partition_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_partition_operations.py index 0de82212ec2..16d710439e3 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_partition_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_partition_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. 
import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -54,7 +52,7 @@ def build_list_metrics_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -99,7 +97,7 @@ def build_list_usages_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -209,7 +207,6 @@ class CollectionPartitionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -225,7 +222,6 @@ class CollectionPartitionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -311,7 +307,6 @@ class CollectionPartitionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -327,7 +322,6 @@ class CollectionPartitionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) 
_request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_partition_region_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_partition_region_operations.py index 9b2ac9bef82..2b5df3c8045 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_partition_region_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_partition_region_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -55,7 +53,7 @@ def build_list_metrics_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -169,7 +167,6 @@ class CollectionPartitionRegionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -185,7 +182,6 @@ class CollectionPartitionRegionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = 
_convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_region_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_region_operations.py index 4e86fbea613..c1ae9995aa4 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_region_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_collection_region_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. 
import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -55,7 +53,7 @@ def build_list_metrics_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -169,7 +167,6 @@ class CollectionRegionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -185,7 +182,6 @@ class CollectionRegionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_database_account_region_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_database_account_region_operations.py index f3ee3023a83..829d1060e3d 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_database_account_region_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_database_account_region_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from 
azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -47,7 +45,7 @@ def build_list_metrics_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -145,7 +143,6 @@ class DatabaseAccountRegionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -161,7 +158,6 @@ class DatabaseAccountRegionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_database_accounts_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_database_accounts_operations.py index 3bd1da8625c..751ed241e34 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_database_accounts_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_database_accounts_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, Callable, Dict, IO, Iterable, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, Callable, Dict, IO, Iterable, Iterator, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.exceptions 
import ( @@ -17,13 +17,14 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse from azure.core.polling import LROPoller, NoPolling, PollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat @@ -31,7 +32,6 @@ from azure.mgmt.core.polling.arm_polling import ARMPolling from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -48,7 +48,7 @@ def build_get_request(resource_group_name: str, account_name: str, subscription_ _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -83,7 +83,7 @@ def build_update_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -121,7 +121,7 @@ def build_create_or_update_request( _headers = 
case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -158,7 +158,7 @@ def build_delete_request( ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) # Construct URL _url = kwargs.pop( "template_url", @@ -188,7 +188,7 @@ def build_failover_priority_change_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) # Construct URL _url = kwargs.pop( @@ -221,7 +221,7 @@ def build_list_request(subscription_id: str, **kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -245,7 +245,7 @@ def build_list_by_resource_group_request(resource_group_name: str, subscription_ _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or 
{}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -277,7 +277,7 @@ def build_list_keys_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -312,7 +312,7 @@ def build_list_connection_strings_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -347,7 +347,7 @@ def build_offline_region_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -385,7 +385,7 @@ def build_online_region_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: 
Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -423,7 +423,7 @@ def build_get_read_only_keys_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -458,7 +458,7 @@ def build_list_read_only_keys_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -493,7 +493,7 @@ def build_regenerate_key_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) # Construct URL _url = kwargs.pop( @@ -525,7 +525,7 @@ def build_regenerate_key_request( def build_check_name_exists_request(account_name: str, **kwargs: Any) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) # Construct URL _url = kwargs.pop("template_url", "/providers/Microsoft.DocumentDB/databaseAccountNames/{accountName}") 
path_format_arguments = { @@ -548,7 +548,7 @@ def build_list_metrics_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -584,7 +584,7 @@ def build_list_usages_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -621,7 +621,7 @@ def build_list_metric_definitions_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -704,7 +704,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -718,7 +717,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response) + deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response.http_response) if 
cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -731,7 +730,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods account_name: str, update_parameters: Union[_models.DatabaseAccountUpdateParameters, IO[bytes]], **kwargs: Any - ) -> _models.DatabaseAccountGetResults: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -745,7 +744,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.DatabaseAccountGetResults] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -766,10 +765,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -777,10 +776,14 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response) + deserialized = 
response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -888,10 +891,11 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response) + deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -919,7 +923,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods account_name: str, create_update_parameters: Union[_models.DatabaseAccountCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> _models.DatabaseAccountGetResults: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -933,7 +937,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.DatabaseAccountGetResults] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -954,10 +958,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, 
**kwargs ) @@ -965,10 +969,14 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1082,10 +1090,11 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response) + deserialized = self._deserialize("DatabaseAccountGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1107,9 +1116,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_initial( # pylint: disable=inconsistent-return-statements - self, resource_group_name: str, account_name: str, **kwargs: Any - ) -> None: + def _delete_initial(self, resource_group_name: str, account_name: str, **kwargs: Any) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1122,7 +1129,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = 
kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_request( resource_group_name=resource_group_name, @@ -1132,10 +1139,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1143,6 +1150,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -1153,8 +1164,12 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete(self, resource_group_name: str, account_name: str, **kwargs: Any) -> LROPoller[None]: @@ -1178,7 +1193,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = 
kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_initial( # type: ignore + raw_result = self._delete_initial( resource_group_name=resource_group_name, account_name=account_name, api_version=api_version, @@ -1187,6 +1202,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1208,13 +1224,13 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods ) return LROPoller[None](self._client, raw_result, get_long_running_output, polling_method) # type: ignore - def _failover_priority_change_initial( # pylint: disable=inconsistent-return-statements + def _failover_priority_change_initial( self, resource_group_name: str, account_name: str, failover_parameters: Union[_models.FailoverPolicies, IO[bytes]], **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1228,7 +1244,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1249,10 +1265,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = 
self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1260,6 +1276,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -1270,8 +1290,12 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @overload def begin_failover_priority_change( @@ -1368,7 +1392,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._failover_priority_change_initial( # type: ignore + raw_result = self._failover_priority_change_initial( resource_group_name=resource_group_name, account_name=account_name, failover_parameters=failover_parameters, @@ -1379,6 +1403,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1432,7 +1457,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods 
headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -1448,7 +1472,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -1515,7 +1538,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -1531,7 +1553,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -1597,7 +1618,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1611,7 +1631,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DatabaseAccountListKeysResult", pipeline_response) + deserialized = self._deserialize("DatabaseAccountListKeysResult", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1655,7 +1675,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) 
_request.url = self._client.format_url(_request.url) _stream = False @@ -1669,20 +1688,20 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DatabaseAccountListConnectionStringsResult", pipeline_response) + deserialized = self._deserialize("DatabaseAccountListConnectionStringsResult", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore - def _offline_region_initial( # pylint: disable=inconsistent-return-statements + def _offline_region_initial( self, resource_group_name: str, account_name: str, region_parameter_for_offline: Union[_models.RegionForOnlineOffline, IO[bytes]], **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1696,7 +1715,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1717,10 +1736,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1728,6 +1747,10 @@ class 
DatabaseAccountsOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) @@ -1739,8 +1762,12 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @overload def begin_offline_region( @@ -1831,7 +1858,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._offline_region_initial( # type: ignore + raw_result = self._offline_region_initial( resource_group_name=resource_group_name, account_name=account_name, region_parameter_for_offline=region_parameter_for_offline, @@ -1842,6 +1869,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1863,13 +1891,13 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods ) return LROPoller[None](self._client, 
raw_result, get_long_running_output, polling_method) # type: ignore - def _online_region_initial( # pylint: disable=inconsistent-return-statements + def _online_region_initial( self, resource_group_name: str, account_name: str, region_parameter_for_online: Union[_models.RegionForOnlineOffline, IO[bytes]], **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1883,7 +1911,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1904,10 +1932,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1915,6 +1943,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) @@ -1926,8 +1958,12 @@ class 
DatabaseAccountsOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @overload def begin_online_region( @@ -2018,7 +2054,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._online_region_initial( # type: ignore + raw_result = self._online_region_initial( resource_group_name=resource_group_name, account_name=account_name, region_parameter_for_online=region_parameter_for_online, @@ -2029,6 +2065,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -2087,7 +2124,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -2101,7 +2137,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DatabaseAccountListReadOnlyKeysResult", pipeline_response) + deserialized = self._deserialize("DatabaseAccountListReadOnlyKeysResult", pipeline_response.http_response) if cls: return 
cls(pipeline_response, deserialized, {}) # type: ignore @@ -2145,7 +2181,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -2159,20 +2194,20 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("DatabaseAccountListReadOnlyKeysResult", pipeline_response) + deserialized = self._deserialize("DatabaseAccountListReadOnlyKeysResult", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore - def _regenerate_key_initial( # pylint: disable=inconsistent-return-statements + def _regenerate_key_initial( self, resource_group_name: str, account_name: str, key_to_regenerate: Union[_models.DatabaseAccountRegenerateKeyParameters, IO[bytes]], **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2186,7 +2221,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2207,10 +2242,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + 
_decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2218,6 +2253,10 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -2228,8 +2267,12 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @overload def begin_regenerate_key( @@ -2318,7 +2361,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._regenerate_key_initial( # type: ignore + raw_result = self._regenerate_key_initial( resource_group_name=resource_group_name, account_name=account_name, key_to_regenerate=key_to_regenerate, @@ -2329,6 +2372,7 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -2382,7 +2426,6 @@ class 
DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -2445,7 +2488,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -2461,7 +2503,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -2535,7 +2576,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -2551,7 +2591,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -2620,7 +2659,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -2636,7 +2674,6 @@ class DatabaseAccountsOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git 
a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_database_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_database_operations.py index 5ac9a39666c..b8a13340d45 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_database_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_database_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -47,7 +45,7 @@ def build_list_metrics_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -90,7 +88,7 @@ def build_list_usages_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -128,7 +126,7 @@ def build_list_metric_definitions_request( _headers = 
case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -226,7 +224,6 @@ class DatabaseOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -242,7 +239,6 @@ class DatabaseOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -324,7 +320,6 @@ class DatabaseOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -340,7 +335,6 @@ class DatabaseOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -412,7 +406,6 @@ class DatabaseOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -428,7 +421,6 @@ class DatabaseOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_gremlin_resources_operations.py 
b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_gremlin_resources_operations.py index 0c35ce44e32..2e7682c8c02 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_gremlin_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_gremlin_resources_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, Callable, Dict, IO, Iterable, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, Callable, Dict, IO, Iterable, Iterator, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.exceptions import ( @@ -17,13 +17,14 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse from azure.core.polling import LROPoller, NoPolling, PollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat @@ -31,7 +32,6 @@ from azure.mgmt.core.polling.arm_polling import ARMPolling from .. 
import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -50,7 +50,7 @@ def build_list_gremlin_databases_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -85,7 +85,7 @@ def build_get_gremlin_database_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -121,7 +121,7 @@ def build_create_update_gremlin_database_request( # pylint: disable=name-too-lo _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -159,7 +159,7 @@ def build_delete_gremlin_database_request( ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) # Construct URL _url = kwargs.pop( "template_url", @@ -190,7 +190,7 @@ def 
build_get_gremlin_database_throughput_request( # pylint: disable=name-too-l _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -226,7 +226,7 @@ def build_update_gremlin_database_throughput_request( # pylint: disable=name-to _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -265,7 +265,7 @@ def build_migrate_gremlin_database_to_autoscale_request( # pylint: disable=name _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -301,7 +301,7 @@ def build_migrate_gremlin_database_to_manual_throughput_request( # pylint: disa _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -337,7 +337,7 @@ def 
build_list_gremlin_graphs_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -378,7 +378,7 @@ def build_get_gremlin_graph_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -420,7 +420,7 @@ def build_create_update_gremlin_graph_request( # pylint: disable=name-too-long _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -464,7 +464,7 @@ def build_delete_gremlin_graph_request( ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) # Construct URL _url = kwargs.pop( "template_url", @@ -501,7 +501,7 @@ def build_get_gremlin_graph_throughput_request( # pylint: disable=name-too-long _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - 
api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -543,7 +543,7 @@ def build_update_gremlin_graph_throughput_request( # pylint: disable=name-too-l _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -588,7 +588,7 @@ def build_migrate_gremlin_graph_to_autoscale_request( # pylint: disable=name-to _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -630,7 +630,7 @@ def build_migrate_gremlin_graph_to_manual_throughput_request( # pylint: disable _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -672,7 +672,7 @@ def build_retrieve_continuous_backup_information_request( # pylint: disable=nam _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = 
kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -766,7 +766,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -782,7 +781,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -852,7 +850,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -866,7 +863,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("GremlinDatabaseGetResults", pipeline_response) + deserialized = self._deserialize("GremlinDatabaseGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -880,7 +877,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods database_name: str, create_update_gremlin_database_parameters: Union[_models.GremlinDatabaseCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.GremlinDatabaseGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 
401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -894,7 +891,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.GremlinDatabaseGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -918,10 +915,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -929,20 +926,22 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("GremlinDatabaseGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return 
cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1063,10 +1062,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("GremlinDatabaseGetResults", pipeline_response) + deserialized = self._deserialize("GremlinDatabaseGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1088,9 +1088,9 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_gremlin_database_initial( # pylint: disable=inconsistent-return-statements + def _delete_gremlin_database_initial( self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1103,7 +1103,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_gremlin_database_request( resource_group_name=resource_group_name, @@ -1114,10 +1114,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: 
disable=protected-access _request, stream=_stream, **kwargs ) @@ -1125,6 +1125,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -1135,8 +1139,12 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_gremlin_database( @@ -1164,7 +1172,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_gremlin_database_initial( # type: ignore + raw_result = self._delete_gremlin_database_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -1174,6 +1182,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1236,7 +1245,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = 
_convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1250,7 +1258,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1264,7 +1272,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods database_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1278,7 +1286,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1300,10 +1308,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1311,20 +1319,22 
@@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1445,10 +1455,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1472,7 +1483,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_gremlin_database_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: 
ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1485,7 +1496,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_gremlin_database_to_autoscale_request( resource_group_name=resource_group_name, @@ -1496,10 +1507,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1507,20 +1518,22 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return 
cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1563,10 +1576,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1590,7 +1604,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_gremlin_database_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1603,7 +1617,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_gremlin_database_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -1614,10 +1628,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: 
disable=protected-access _request, stream=_stream, **kwargs ) @@ -1625,20 +1639,22 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1681,10 +1697,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1750,7 +1767,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -1766,7 +1782,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _request = 
HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -1838,7 +1853,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1852,7 +1866,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("GremlinGraphGetResults", pipeline_response) + deserialized = self._deserialize("GremlinGraphGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1867,7 +1881,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods graph_name: str, create_update_gremlin_graph_parameters: Union[_models.GremlinGraphCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.GremlinGraphGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1881,7 +1895,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.GremlinGraphGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1904,10 +1918,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods 
headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1915,20 +1929,22 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("GremlinGraphGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2059,10 +2075,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("GremlinGraphGetResults", pipeline_response) + deserialized = self._deserialize("GremlinGraphGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2084,9 +2101,9 @@ class GremlinResourcesOperations: # pylint: 
disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_gremlin_graph_initial( # pylint: disable=inconsistent-return-statements + def _delete_gremlin_graph_initial( self, resource_group_name: str, account_name: str, database_name: str, graph_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2099,7 +2116,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_gremlin_graph_request( resource_group_name=resource_group_name, @@ -2111,10 +2128,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2122,6 +2139,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -2132,8 +2153,12 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods ) 
response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_gremlin_graph( @@ -2163,7 +2188,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_gremlin_graph_initial( # type: ignore + raw_result = self._delete_gremlin_graph_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -2174,6 +2199,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -2239,7 +2265,6 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -2253,7 +2278,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2268,7 +2293,7 @@ class 
GremlinResourcesOperations: # pylint: disable=too-many-public-methods graph_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2282,7 +2307,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2305,10 +2330,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2316,20 +2341,22 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if 
response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2460,10 +2487,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2487,7 +2515,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_gremlin_graph_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, graph_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2500,7 +2528,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_gremlin_graph_to_autoscale_request( resource_group_name=resource_group_name, @@ -2512,10 +2540,10 @@ class GremlinResourcesOperations: # 
pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2523,20 +2551,22 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2582,10 +2612,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2609,7 +2640,7 @@ 
class GremlinResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_gremlin_graph_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, graph_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2622,7 +2653,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_gremlin_graph_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -2634,10 +2665,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2645,20 +2676,22 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = 
self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2704,10 +2737,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2737,7 +2771,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods graph_name: str, location: Union[_models.ContinuousBackupRestoreLocation, IO[bytes]], **kwargs: Any - ) -> Optional[_models.BackupInformation]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2751,7 +2785,7 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.BackupInformation]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2774,10 +2808,10 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods 
headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2785,12 +2819,14 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2917,10 +2953,11 @@ class GremlinResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = self._deserialize("BackupInformation", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_locations_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_locations_operations.py index a704c5d714b..18d9564090a 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_locations_operations.py +++ 
b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_locations_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -45,7 +43,7 @@ def build_list_request(subscription_id: str, **kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -69,7 +67,7 @@ def build_get_request(location: str, subscription_id: str, **kwargs: Any) -> Htt _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -142,7 +140,6 @@ class LocationsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -158,7 +155,6 @@ class LocationsOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, 
_parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -219,7 +215,6 @@ class LocationsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -233,7 +228,7 @@ class LocationsOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("LocationGetResult", pipeline_response) + deserialized = self._deserialize("LocationGetResult", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_mongo_db_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_mongo_db_resources_operations.py index 4b87abe9132..28a675a2ae6 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_mongo_db_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_mongo_db_resources_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, Callable, Dict, IO, Iterable, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, Callable, Dict, IO, Iterable, Iterator, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.exceptions import ( @@ -17,13 +17,14 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse from 
azure.core.polling import LROPoller, NoPolling, PollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat @@ -31,7 +32,6 @@ from azure.mgmt.core.polling.arm_polling import ARMPolling from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -50,7 +50,7 @@ def build_list_mongo_db_databases_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -85,7 +85,7 @@ def build_get_mongo_db_database_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -121,7 +121,7 @@ def build_create_update_mongo_db_database_request( # pylint: disable=name-too-l _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", 
"application/json") @@ -159,7 +159,7 @@ def build_delete_mongo_db_database_request( ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) # Construct URL _url = kwargs.pop( "template_url", @@ -190,7 +190,7 @@ def build_get_mongo_db_database_throughput_request( # pylint: disable=name-too- _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -226,7 +226,7 @@ def build_update_mongo_db_database_throughput_request( # pylint: disable=name-t _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -265,7 +265,7 @@ def build_migrate_mongo_db_database_to_autoscale_request( # pylint: disable=nam _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -301,7 +301,7 @@ def build_migrate_mongo_db_database_to_manual_throughput_request( # pylint: dis 
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -337,7 +337,7 @@ def build_list_mongo_db_collections_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -378,7 +378,7 @@ def build_get_mongo_db_collection_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -420,7 +420,7 @@ def build_create_update_mongo_db_collection_request( # pylint: disable=name-too _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -464,7 +464,7 @@ def build_delete_mongo_db_collection_request( ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = 
kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) # Construct URL _url = kwargs.pop( "template_url", @@ -501,7 +501,7 @@ def build_get_mongo_db_collection_throughput_request( # pylint: disable=name-to _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -543,7 +543,7 @@ def build_update_mongo_db_collection_throughput_request( # pylint: disable=name _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -588,7 +588,7 @@ def build_migrate_mongo_db_collection_to_autoscale_request( # pylint: disable=n _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -630,7 +630,7 @@ def build_migrate_mongo_db_collection_to_manual_throughput_request( # pylint: d _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", 
_params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -667,7 +667,7 @@ def build_get_mongo_role_definition_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -703,7 +703,7 @@ def build_create_update_mongo_role_definition_request( # pylint: disable=name-t _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -742,7 +742,7 @@ def build_delete_mongo_role_definition_request( # pylint: disable=name-too-long _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -778,7 +778,7 @@ def build_list_mongo_role_definitions_request( # pylint: disable=name-too-long _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + 
api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -813,7 +813,7 @@ def build_get_mongo_user_definition_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -849,7 +849,7 @@ def build_create_update_mongo_user_definition_request( # pylint: disable=name-t _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -888,7 +888,7 @@ def build_delete_mongo_user_definition_request( # pylint: disable=name-too-long _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -924,7 +924,7 @@ def build_list_mongo_user_definitions_request( # pylint: disable=name-too-long _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", 
_params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -964,7 +964,7 @@ def build_retrieve_continuous_backup_information_request( # pylint: disable=nam _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -1058,7 +1058,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -1074,7 +1073,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -1144,7 +1142,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1158,7 +1155,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("MongoDBDatabaseGetResults", pipeline_response) + deserialized = self._deserialize("MongoDBDatabaseGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ 
-1172,7 +1169,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods database_name: str, create_update_mongo_db_database_parameters: Union[_models.MongoDBDatabaseCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.MongoDBDatabaseGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1186,7 +1183,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.MongoDBDatabaseGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1210,10 +1207,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1221,20 +1218,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("MongoDBDatabaseGetResults", 
pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1355,10 +1354,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("MongoDBDatabaseGetResults", pipeline_response) + deserialized = self._deserialize("MongoDBDatabaseGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1380,9 +1380,9 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_mongo_db_database_initial( # pylint: disable=inconsistent-return-statements + def _delete_mongo_db_database_initial( self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1395,7 +1395,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_mongo_db_database_request( resource_group_name=resource_group_name, @@ -1406,10 +1406,10 @@ 
class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1417,6 +1417,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -1427,8 +1431,12 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_mongo_db_database( @@ -1456,7 +1464,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_mongo_db_database_initial( # type: ignore + raw_result = self._delete_mongo_db_database_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -1466,6 +1474,7 @@ class MongoDBResourcesOperations: # pylint: 
disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -1528,7 +1537,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1542,7 +1550,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1556,7 +1564,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods database_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1570,7 +1578,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1592,10 +1600,10 @@ class MongoDBResourcesOperations: # pylint: 
disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1603,20 +1611,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1737,10 +1747,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1764,7 +1775,7 @@ class 
MongoDBResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_mongo_db_database_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1777,7 +1788,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_mongo_db_database_to_autoscale_request( resource_group_name=resource_group_name, @@ -1788,10 +1799,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1799,20 +1810,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = 
self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1855,10 +1868,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1882,7 +1896,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_mongo_db_database_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1895,7 +1909,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_mongo_db_database_to_manual_throughput_request( 
resource_group_name=resource_group_name, @@ -1906,10 +1920,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1917,20 +1931,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1973,10 +1989,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return 
cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2042,7 +2059,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -2058,7 +2074,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -2130,7 +2145,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -2144,7 +2158,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("MongoDBCollectionGetResults", pipeline_response) + deserialized = self._deserialize("MongoDBCollectionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2159,7 +2173,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods collection_name: str, create_update_mongo_db_collection_parameters: Union[_models.MongoDBCollectionCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.MongoDBCollectionGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2173,7 +2187,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", 
_params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.MongoDBCollectionGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2198,10 +2212,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2209,20 +2223,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("MongoDBCollectionGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2353,10 +2369,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + 
raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("MongoDBCollectionGetResults", pipeline_response) + deserialized = self._deserialize("MongoDBCollectionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2378,9 +2395,9 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_mongo_db_collection_initial( # pylint: disable=inconsistent-return-statements + def _delete_mongo_db_collection_initial( self, resource_group_name: str, account_name: str, database_name: str, collection_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2393,7 +2410,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_mongo_db_collection_request( resource_group_name=resource_group_name, @@ -2405,10 +2422,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2416,6 +2433,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = 
pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -2426,8 +2447,12 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_mongo_db_collection( @@ -2457,7 +2482,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_mongo_db_collection_initial( # type: ignore + raw_result = self._delete_mongo_db_collection_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -2468,6 +2493,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -2533,7 +2559,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -2547,7 +2572,7 @@ class MongoDBResourcesOperations: # pylint: 
disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2562,7 +2587,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods collection_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2576,7 +2601,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2599,10 +2624,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2610,20 +2635,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 
202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2754,10 +2781,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2781,7 +2809,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_mongo_db_collection_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, collection_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2794,7 +2822,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods 
_params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_mongo_db_collection_to_autoscale_request( resource_group_name=resource_group_name, @@ -2806,10 +2834,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2817,20 +2845,22 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2876,10 +2906,11 @@ class MongoDBResourcesOperations: # pylint: 
disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2903,7 +2934,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_mongo_db_collection_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, collection_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2916,7 +2947,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_mongo_db_collection_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -2928,10 +2959,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2939,20 +2970,22 @@ class MongoDBResourcesOperations: # 
pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2998,10 +3031,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -3064,7 +3098,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -3078,7 +3111,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, 
error_format=ARMErrorFormat) - deserialized = self._deserialize("MongoRoleDefinitionGetResults", pipeline_response) + deserialized = self._deserialize("MongoRoleDefinitionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -3094,7 +3127,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _models.MongoRoleDefinitionCreateUpdateParameters, IO[bytes] ], **kwargs: Any - ) -> Optional[_models.MongoRoleDefinitionGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -3108,7 +3141,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.MongoRoleDefinitionGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -3132,10 +3165,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -3143,12 +3176,14 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, 
error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("MongoRoleDefinitionGetResults", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -3275,10 +3310,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("MongoRoleDefinitionGetResults", pipeline_response) + deserialized = self._deserialize("MongoRoleDefinitionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -3300,9 +3336,9 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_mongo_role_definition_initial( # pylint: disable=inconsistent-return-statements + def _delete_mongo_role_definition_initial( self, mongo_role_definition_id: str, resource_group_name: str, account_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -3315,7 +3351,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_mongo_role_definition_request( mongo_role_definition_id=mongo_role_definition_id, @@ -3326,10 +3362,10 @@ 
class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -3337,11 +3373,19 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_mongo_role_definition( @@ -3369,7 +3413,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_mongo_role_definition_initial( # type: ignore + raw_result = self._delete_mongo_role_definition_initial( mongo_role_definition_id=mongo_role_definition_id, resource_group_name=resource_group_name, account_name=account_name, @@ -3379,6 +3423,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: 
disable=inconsistent-return-statements @@ -3441,7 +3486,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -3457,7 +3501,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -3527,7 +3570,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -3541,7 +3583,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("MongoUserDefinitionGetResults", pipeline_response) + deserialized = self._deserialize("MongoUserDefinitionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -3557,7 +3599,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _models.MongoUserDefinitionCreateUpdateParameters, IO[bytes] ], **kwargs: Any - ) -> Optional[_models.MongoUserDefinitionGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -3571,7 +3613,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = 
kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.MongoUserDefinitionGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -3595,10 +3637,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -3606,12 +3648,14 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("MongoUserDefinitionGetResults", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -3738,10 +3782,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("MongoUserDefinitionGetResults", pipeline_response) + deserialized = self._deserialize("MongoUserDefinitionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # 
type: ignore return deserialized @@ -3763,9 +3808,9 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_mongo_user_definition_initial( # pylint: disable=inconsistent-return-statements + def _delete_mongo_user_definition_initial( self, mongo_user_definition_id: str, resource_group_name: str, account_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -3778,7 +3823,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_mongo_user_definition_request( mongo_user_definition_id=mongo_user_definition_id, @@ -3789,10 +3834,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -3800,11 +3845,19 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, 
error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_mongo_user_definition( @@ -3832,7 +3885,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_mongo_user_definition_initial( # type: ignore + raw_result = self._delete_mongo_user_definition_initial( mongo_user_definition_id=mongo_user_definition_id, resource_group_name=resource_group_name, account_name=account_name, @@ -3842,6 +3895,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -3904,7 +3958,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -3920,7 +3973,6 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -3957,7 +4009,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods collection_name: str, location: Union[_models.ContinuousBackupRestoreLocation, IO[bytes]], **kwargs: Any - ) -> Optional[_models.BackupInformation]: + ) -> 
Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -3971,7 +4023,7 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.BackupInformation]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -3994,10 +4046,10 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -4005,12 +4057,14 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -4137,10 +4191,11 @@ class MongoDBResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() 
# type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = self._deserialize("BackupInformation", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_notebook_workspaces_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_notebook_workspaces_operations.py index e806ac3360e..66a5157170f 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_notebook_workspaces_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_notebook_workspaces_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, Callable, Dict, IO, Iterable, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, Callable, Dict, IO, Iterable, Iterator, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.exceptions import ( @@ -17,13 +17,14 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse from azure.core.polling import LROPoller, NoPolling, PollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat @@ -31,7 +32,6 @@ from azure.mgmt.core.polling.arm_polling import ARMPolling from .. 
import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -50,7 +50,7 @@ def build_list_by_database_account_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -89,7 +89,7 @@ def build_get_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -129,7 +129,7 @@ def build_create_or_update_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -172,7 +172,7 @@ def build_delete_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -212,7 +212,7 @@ def 
build_list_connection_info_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -252,7 +252,7 @@ def build_regenerate_auth_token_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -292,7 +292,7 @@ def build_start_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -381,7 +381,6 @@ class NotebookWorkspacesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -397,7 +396,6 @@ class NotebookWorkspacesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -472,7 +470,6 @@ class NotebookWorkspacesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -487,7 +484,7 @@ class 
NotebookWorkspacesOperations: error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) - deserialized = self._deserialize("NotebookWorkspace", pipeline_response) + deserialized = self._deserialize("NotebookWorkspace", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -501,7 +498,7 @@ class NotebookWorkspacesOperations: notebook_workspace_name: Union[str, _models.NotebookWorkspaceName], notebook_create_update_parameters: Union[_models.NotebookWorkspaceCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> _models.NotebookWorkspace: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -515,7 +512,7 @@ class NotebookWorkspacesOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[_models.NotebookWorkspace] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -537,10 +534,10 @@ class NotebookWorkspacesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -548,11 +545,15 @@ class NotebookWorkspacesOperations: response = pipeline_response.http_response if response.status_code not in [200]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass 
map_error(status_code=response.status_code, response=response, error_map=error_map) error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) - deserialized = self._deserialize("NotebookWorkspace", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -677,10 +678,11 @@ class NotebookWorkspacesOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("NotebookWorkspace", pipeline_response) + deserialized = self._deserialize("NotebookWorkspace", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -702,13 +704,13 @@ class NotebookWorkspacesOperations: self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_initial( # pylint: disable=inconsistent-return-statements + def _delete_initial( self, resource_group_name: str, account_name: str, notebook_workspace_name: Union[str, _models.NotebookWorkspaceName], **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -721,7 +723,7 @@ class NotebookWorkspacesOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_request( resource_group_name=resource_group_name, @@ -732,10 +734,10 @@ class NotebookWorkspacesOperations: headers=_headers, params=_params, ) - _request = 
_convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -743,12 +745,20 @@ class NotebookWorkspacesOperations: response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete( @@ -781,7 +791,7 @@ class NotebookWorkspacesOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_initial( # type: ignore + raw_result = self._delete_initial( resource_group_name=resource_group_name, account_name=account_name, notebook_workspace_name=notebook_workspace_name, @@ -791,6 +801,7 @@ class NotebookWorkspacesOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -857,7 +868,6 @@ class NotebookWorkspacesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) 
_stream = False @@ -872,20 +882,20 @@ class NotebookWorkspacesOperations: error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) - deserialized = self._deserialize("NotebookWorkspaceConnectionInfoResult", pipeline_response) + deserialized = self._deserialize("NotebookWorkspaceConnectionInfoResult", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore - def _regenerate_auth_token_initial( # pylint: disable=inconsistent-return-statements + def _regenerate_auth_token_initial( self, resource_group_name: str, account_name: str, notebook_workspace_name: Union[str, _models.NotebookWorkspaceName], **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -898,7 +908,7 @@ class NotebookWorkspacesOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_regenerate_auth_token_request( resource_group_name=resource_group_name, @@ -909,10 +919,10 @@ class NotebookWorkspacesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -920,12 +930,20 @@ class NotebookWorkspacesOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close 
the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_regenerate_auth_token( @@ -958,7 +976,7 @@ class NotebookWorkspacesOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._regenerate_auth_token_initial( # type: ignore + raw_result = self._regenerate_auth_token_initial( resource_group_name=resource_group_name, account_name=account_name, notebook_workspace_name=notebook_workspace_name, @@ -968,6 +986,7 @@ class NotebookWorkspacesOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -989,13 +1008,13 @@ class NotebookWorkspacesOperations: ) return LROPoller[None](self._client, raw_result, get_long_running_output, polling_method) # type: ignore - def _start_initial( # pylint: disable=inconsistent-return-statements + def _start_initial( self, resource_group_name: str, account_name: str, notebook_workspace_name: Union[str, _models.NotebookWorkspaceName], **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1008,7 +1027,7 @@ class NotebookWorkspacesOperations: _params = 
case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_start_request( resource_group_name=resource_group_name, @@ -1019,10 +1038,10 @@ class NotebookWorkspacesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1030,12 +1049,20 @@ class NotebookWorkspacesOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_start( @@ -1068,7 +1095,7 @@ class NotebookWorkspacesOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._start_initial( # type: ignore + raw_result = self._start_initial( resource_group_name=resource_group_name, account_name=account_name, notebook_workspace_name=notebook_workspace_name, @@ 
-1078,6 +1105,7 @@ class NotebookWorkspacesOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_operations.py index 300b0fdb646..d66f3020882 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. 
import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -45,7 +43,7 @@ def build_list_request(**kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -109,7 +107,6 @@ class Operations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -125,7 +122,6 @@ class Operations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_partition_key_range_id_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_partition_key_range_id_operations.py index b061e1bfc90..bb31f8cebc4 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_partition_key_range_id_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_partition_key_range_id_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions 
import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -55,7 +53,7 @@ def build_list_metrics_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -168,7 +166,6 @@ class PartitionKeyRangeIdOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -184,7 +181,6 @@ class PartitionKeyRangeIdOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_partition_key_range_id_region_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_partition_key_range_id_region_operations.py index 79e584de3aa..9548be9c3a2 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_partition_key_range_id_region_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_partition_key_range_id_region_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from 
azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -56,7 +54,7 @@ def build_list_metrics_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -175,7 +173,6 @@ class PartitionKeyRangeIdRegionOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -191,7 +188,6 @@ class PartitionKeyRangeIdRegionOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_percentile_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_percentile_operations.py index a8557627c62..347587af16d 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_percentile_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_percentile_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace 
from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -47,7 +45,7 @@ def build_list_metrics_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -142,7 +140,6 @@ class PercentileOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -158,7 +155,6 @@ class PercentileOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_percentile_source_target_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_percentile_source_target_operations.py index ee8702c68cb..767125c2a9c 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_percentile_source_target_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_percentile_source_target_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from 
azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -54,7 +52,7 @@ def build_list_metrics_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -165,7 +163,6 @@ class PercentileSourceTargetOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -181,7 +178,6 @@ class PercentileSourceTargetOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_percentile_target_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_percentile_target_operations.py index f64e0de09ee..c92c6dfea5f 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_percentile_target_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_percentile_target_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import 
HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -47,7 +45,7 @@ def build_list_metrics_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -147,7 +145,6 @@ class PercentileTargetOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -163,7 +160,6 @@ class PercentileTargetOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_private_endpoint_connections_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_private_endpoint_connections_operations.py index f861cead1c1..d0477388345 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_private_endpoint_connections_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_private_endpoint_connections_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, Callable, Dict, IO, Iterable, Optional, Type, TypeVar, Union, 
cast, overload +from typing import Any, Callable, Dict, IO, Iterable, Iterator, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.exceptions import ( @@ -17,13 +17,14 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse from azure.core.polling import LROPoller, NoPolling, PollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat @@ -31,7 +32,6 @@ from azure.mgmt.core.polling.arm_polling import ARMPolling from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -50,7 +50,7 @@ def build_list_by_database_account_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -89,7 +89,7 @@ def build_get_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -131,7 +131,7 @@ def 
build_create_or_update_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -176,7 +176,7 @@ def build_delete_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -268,7 +268,6 @@ class PrivateEndpointConnectionsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -284,7 +283,6 @@ class PrivateEndpointConnectionsOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -353,7 +351,6 @@ class PrivateEndpointConnectionsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -367,7 +364,7 @@ class PrivateEndpointConnectionsOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("PrivateEndpointConnection", pipeline_response) + deserialized = self._deserialize("PrivateEndpointConnection", 
pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -381,7 +378,7 @@ class PrivateEndpointConnectionsOperations: private_endpoint_connection_name: str, parameters: Union[_models.PrivateEndpointConnection, IO[bytes]], **kwargs: Any - ) -> Optional[_models.PrivateEndpointConnection]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -395,7 +392,7 @@ class PrivateEndpointConnectionsOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.PrivateEndpointConnection]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -417,10 +414,10 @@ class PrivateEndpointConnectionsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -428,13 +425,15 @@ class PrivateEndpointConnectionsOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) - error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) + error = self._deserialize.failsafe_deserialize(_models.ErrorResponseAutoGenerated, pipeline_response) raise HttpResponseError(response=response, model=error, 
error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("PrivateEndpointConnection", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -550,10 +549,11 @@ class PrivateEndpointConnectionsOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("PrivateEndpointConnection", pipeline_response) + deserialized = self._deserialize("PrivateEndpointConnection", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -575,9 +575,9 @@ class PrivateEndpointConnectionsOperations: self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_initial( # pylint: disable=inconsistent-return-statements + def _delete_initial( self, resource_group_name: str, account_name: str, private_endpoint_connection_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -590,7 +590,7 @@ class PrivateEndpointConnectionsOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_request( resource_group_name=resource_group_name, @@ -601,10 +601,10 @@ class PrivateEndpointConnectionsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream 
= True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -612,12 +612,20 @@ class PrivateEndpointConnectionsOperations: response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) - error = self._deserialize.failsafe_deserialize(_models.ErrorResponse, pipeline_response) + error = self._deserialize.failsafe_deserialize(_models.ErrorResponseAutoGenerated, pipeline_response) raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete( @@ -645,7 +653,7 @@ class PrivateEndpointConnectionsOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_initial( # type: ignore + raw_result = self._delete_initial( resource_group_name=resource_group_name, account_name=account_name, private_endpoint_connection_name=private_endpoint_connection_name, @@ -655,6 +663,7 @@ class PrivateEndpointConnectionsOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_private_link_resources_operations.py 
b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_private_link_resources_operations.py index 3249f4c9ff6..3609926a7b2 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_private_link_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_private_link_resources_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -47,7 +45,7 @@ def build_list_by_database_account_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -82,7 +80,7 @@ def build_get_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -171,7 +169,6 @@ class PrivateLinkResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) 
_request.url = self._client.format_url(_request.url) else: @@ -187,7 +184,6 @@ class PrivateLinkResourcesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -256,7 +252,6 @@ class PrivateLinkResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -270,7 +265,7 @@ class PrivateLinkResourcesOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("PrivateLinkResource", pipeline_response) + deserialized = self._deserialize("PrivateLinkResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_database_accounts_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_database_accounts_operations.py index 8cd051b8ed0..bd41d28ba25 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_database_accounts_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_database_accounts_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. 
import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -45,7 +43,7 @@ def build_list_by_location_request(location: str, subscription_id: str, **kwargs _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -73,7 +71,7 @@ def build_list_request(subscription_id: str, **kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -99,7 +97,7 @@ def build_get_by_location_request(location: str, instance_id: str, subscription_ _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -182,7 +180,6 @@ class RestorableDatabaseAccountsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -198,7 +195,6 @@ class RestorableDatabaseAccountsOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = 
_convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -261,7 +257,6 @@ class RestorableDatabaseAccountsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -277,7 +272,6 @@ class RestorableDatabaseAccountsOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -344,7 +338,6 @@ class RestorableDatabaseAccountsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -358,7 +351,7 @@ class RestorableDatabaseAccountsOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("RestorableDatabaseAccountGetResult", pipeline_response) + deserialized = self._deserialize("RestorableDatabaseAccountGetResult", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_gremlin_databases_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_gremlin_databases_operations.py index ef0590a8ade..fa72e6a2b40 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_gremlin_databases_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_gremlin_databases_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from 
azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -45,7 +43,7 @@ def build_list_request(location: str, instance_id: str, subscription_id: str, ** _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -134,7 +132,6 @@ class RestorableGremlinDatabasesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -150,7 +147,6 @@ class RestorableGremlinDatabasesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_gremlin_graphs_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_gremlin_graphs_operations.py index 2c4f835fc01..fc9bb5bcdeb 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_gremlin_graphs_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_gremlin_graphs_operations.py @@ -20,15 +20,13 @@ from 
azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -54,7 +52,7 @@ def build_list_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -166,7 +164,6 @@ class RestorableGremlinGraphsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -182,7 +179,6 @@ class RestorableGremlinGraphsOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_gremlin_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_gremlin_resources_operations.py index c9877dad1b3..58997b34088 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_gremlin_resources_operations.py +++ 
b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_gremlin_resources_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -53,7 +51,7 @@ def build_list_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -161,7 +159,6 @@ class RestorableGremlinResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -177,7 +174,6 @@ class RestorableGremlinResourcesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_mongodb_collections_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_mongodb_collections_operations.py index 57b39e7efc9..04bd3e408fe 100644 --- 
a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_mongodb_collections_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_mongodb_collections_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -54,7 +52,7 @@ def build_list_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -166,7 +164,6 @@ class RestorableMongodbCollectionsOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -182,7 +179,6 @@ class RestorableMongodbCollectionsOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_mongodb_databases_operations.py 
b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_mongodb_databases_operations.py index ae59bd932d2..413da88e355 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_mongodb_databases_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_mongodb_databases_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -45,7 +43,7 @@ def build_list_request(location: str, instance_id: str, subscription_id: str, ** _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -134,7 +132,6 @@ class RestorableMongodbDatabasesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -150,7 +147,6 @@ class RestorableMongodbDatabasesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return 
_request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_mongodb_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_mongodb_resources_operations.py index 92209f4fdc8..eb552e716c3 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_mongodb_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_mongodb_resources_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -53,7 +51,7 @@ def build_list_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -161,7 +159,6 @@ class RestorableMongodbResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -177,7 +174,6 @@ class RestorableMongodbResourcesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url 
= self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_sql_containers_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_sql_containers_operations.py index a03339e2eeb..fc7922e44c1 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_sql_containers_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_sql_containers_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. 
import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -54,7 +52,7 @@ def build_list_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -165,7 +163,6 @@ class RestorableSqlContainersOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -181,7 +178,6 @@ class RestorableSqlContainersOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_sql_databases_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_sql_databases_operations.py index 476ffc0656e..04cbe6e7b6e 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_sql_databases_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_sql_databases_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from 
azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -45,7 +43,7 @@ def build_list_request(location: str, instance_id: str, subscription_id: str, ** _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -134,7 +132,6 @@ class RestorableSqlDatabasesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -150,7 +147,6 @@ class RestorableSqlDatabasesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_sql_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_sql_resources_operations.py index e92ea922bea..073e092e64d 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_sql_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_sql_resources_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from 
azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -53,7 +51,7 @@ def build_list_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -161,7 +159,6 @@ class RestorableSqlResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -177,7 +174,6 @@ class RestorableSqlResourcesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_table_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_table_resources_operations.py index 5400a09def5..8bb7de833e2 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_table_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_table_resources_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from azure.core.rest import HttpRequest 
+from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -53,7 +51,7 @@ def build_list_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -160,7 +158,6 @@ class RestorableTableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -176,7 +173,6 @@ class RestorableTableResourcesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_tables_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_tables_operations.py index 69d0e15c2d0..a219965b5c3 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_tables_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_restorable_tables_operations.py @@ -20,15 +20,13 @@ from azure.core.exceptions import ( ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse -from 
azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -53,7 +51,7 @@ def build_list_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -155,7 +153,6 @@ class RestorableTablesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -171,7 +168,6 @@ class RestorableTablesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_service_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_service_operations.py index 5aa131a3adb..38a0459cc8e 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_service_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_service_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, Callable, Dict, IO, Iterable, Optional, Type, TypeVar, Union, cast, overload +from typing 
import Any, Callable, Dict, IO, Iterable, Iterator, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.exceptions import ( @@ -17,13 +17,14 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse from azure.core.polling import LROPoller, NoPolling, PollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat @@ -31,7 +32,6 @@ from azure.mgmt.core.polling.arm_polling import ARMPolling from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -48,7 +48,7 @@ def build_list_request(resource_group_name: str, account_name: str, subscription _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -83,7 +83,7 @@ def build_create_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) 
accept = _headers.pop("Accept", "application/json") @@ -122,7 +122,7 @@ def build_get_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -158,7 +158,7 @@ def build_delete_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -245,7 +245,6 @@ class ServiceOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -261,7 +260,6 @@ class ServiceOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -297,7 +295,7 @@ class ServiceOperations: service_name: str, create_update_parameters: Union[_models.ServiceResourceCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ServiceResource]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -311,7 +309,7 @@ class ServiceOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: 
ClsType[Optional[_models.ServiceResource]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -333,10 +331,10 @@ class ServiceOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -344,20 +342,22 @@ class ServiceOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ServiceResource", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -475,10 +475,11 @@ class ServiceOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ServiceResource", pipeline_response) + deserialized = self._deserialize("ServiceResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: 
ignore return deserialized @@ -540,7 +541,6 @@ class ServiceOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -554,16 +554,16 @@ class ServiceOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ServiceResource", pipeline_response) + deserialized = self._deserialize("ServiceResource", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore - def _delete_initial( # pylint: disable=inconsistent-return-statements + def _delete_initial( self, resource_group_name: str, account_name: str, service_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -576,7 +576,7 @@ class ServiceOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_request( resource_group_name=resource_group_name, @@ -587,10 +587,10 @@ class ServiceOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -598,6 +598,10 @@ class ServiceOperations: response = pipeline_response.http_response if response.status_code not in [200, 202, 204]: + try: + response.read() # Load the body in memory and 
close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -608,8 +612,12 @@ class ServiceOperations: ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete( @@ -637,7 +645,7 @@ class ServiceOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_initial( # type: ignore + raw_result = self._delete_initial( resource_group_name=resource_group_name, account_name=account_name, service_name=service_name, @@ -647,6 +655,7 @@ class ServiceOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_sql_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_sql_resources_operations.py index 5eb7aad6dd6..6ac69c082d8 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_sql_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_sql_resources_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, Callable, Dict, IO, Iterable, Optional, Type, TypeVar, Union, cast, overload +from 
typing import Any, Callable, Dict, IO, Iterable, Iterator, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.exceptions import ( @@ -17,13 +17,14 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse from azure.core.polling import LROPoller, NoPolling, PollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat @@ -31,7 +32,6 @@ from azure.mgmt.core.polling.arm_polling import ARMPolling from .. import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -50,7 +50,7 @@ def build_list_sql_databases_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -85,7 +85,7 @@ def build_get_sql_database_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -121,7 +121,7 @@ def 
build_create_update_sql_database_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -159,7 +159,7 @@ def build_delete_sql_database_request( ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) # Construct URL _url = kwargs.pop( "template_url", @@ -190,7 +190,7 @@ def build_get_sql_database_throughput_request( # pylint: disable=name-too-long _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -226,7 +226,7 @@ def build_update_sql_database_throughput_request( # pylint: disable=name-too-lo _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -265,7 +265,7 @@ def build_migrate_sql_database_to_autoscale_request( # pylint: disable=name-too 
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -301,7 +301,7 @@ def build_migrate_sql_database_to_manual_throughput_request( # pylint: disable= _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -337,7 +337,7 @@ def build_list_sql_containers_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -378,7 +378,7 @@ def build_get_sql_container_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -420,7 +420,7 @@ def build_create_update_sql_container_request( # pylint: disable=name-too-long _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", 
_params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -464,7 +464,7 @@ def build_delete_sql_container_request( ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) # Construct URL _url = kwargs.pop( "template_url", @@ -501,7 +501,7 @@ def build_get_sql_container_throughput_request( # pylint: disable=name-too-long _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -543,7 +543,7 @@ def build_update_sql_container_throughput_request( # pylint: disable=name-too-l _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -588,7 +588,7 @@ def build_migrate_sql_container_to_autoscale_request( # pylint: disable=name-to _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + 
api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -630,7 +630,7 @@ def build_migrate_sql_container_to_manual_throughput_request( # pylint: disable _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -667,7 +667,7 @@ def build_list_client_encryption_keys_request( # pylint: disable=name-too-long _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -708,7 +708,7 @@ def build_get_client_encryption_key_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -750,7 +750,7 @@ def build_create_update_client_encryption_key_request( # pylint: disable=name-t _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = 
kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -795,7 +795,7 @@ def build_list_sql_stored_procedures_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -838,7 +838,7 @@ def build_get_sql_stored_procedure_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -882,7 +882,7 @@ def build_create_update_sql_stored_procedure_request( # pylint: disable=name-to _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -928,7 +928,7 @@ def build_delete_sql_stored_procedure_request( # pylint: disable=name-too-long ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) # Construct URL _url = kwargs.pop( "template_url", @@ -966,7 +966,7 @@ def 
build_list_sql_user_defined_functions_request( # pylint: disable=name-too-l _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -1009,7 +1009,7 @@ def build_get_sql_user_defined_function_request( # pylint: disable=name-too-lon _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -1053,7 +1053,7 @@ def build_create_update_sql_user_defined_function_request( # pylint: disable=na _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -1099,7 +1099,7 @@ def build_delete_sql_user_defined_function_request( # pylint: disable=name-too- ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) # Construct URL _url = kwargs.pop( "template_url", @@ -1137,7 +1137,7 @@ def build_list_sql_triggers_request( _headers = 
case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -1180,7 +1180,7 @@ def build_get_sql_trigger_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -1224,7 +1224,7 @@ def build_create_update_sql_trigger_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -1270,7 +1270,7 @@ def build_delete_sql_trigger_request( ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) # Construct URL _url = kwargs.pop( "template_url", @@ -1303,7 +1303,7 @@ def build_get_sql_role_definition_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = 
kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -1339,7 +1339,7 @@ def build_create_update_sql_role_definition_request( # pylint: disable=name-too _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -1378,7 +1378,7 @@ def build_delete_sql_role_definition_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -1414,7 +1414,7 @@ def build_list_sql_role_definitions_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -1449,7 +1449,7 @@ def build_get_sql_role_assignment_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", 
"application/json") # Construct URL @@ -1485,7 +1485,7 @@ def build_create_update_sql_role_assignment_request( # pylint: disable=name-too _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -1524,7 +1524,7 @@ def build_delete_sql_role_assignment_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -1560,7 +1560,7 @@ def build_list_sql_role_assignments_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -1600,7 +1600,7 @@ def build_retrieve_continuous_backup_information_request( # pylint: disable=nam _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = 
_headers.pop("Accept", "application/json") @@ -1694,7 +1694,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -1710,7 +1709,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -1780,7 +1778,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -1794,7 +1791,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("SqlDatabaseGetResults", pipeline_response) + deserialized = self._deserialize("SqlDatabaseGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1808,7 +1805,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods database_name: str, create_update_sql_database_parameters: Union[_models.SqlDatabaseCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.SqlDatabaseGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1822,7 +1819,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = 
kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.SqlDatabaseGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1844,10 +1841,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1855,20 +1852,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("SqlDatabaseGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1988,10 +1987,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def 
get_long_running_output(pipeline_response): - deserialized = self._deserialize("SqlDatabaseGetResults", pipeline_response) + deserialized = self._deserialize("SqlDatabaseGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2013,9 +2013,9 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_sql_database_initial( # pylint: disable=inconsistent-return-statements + def _delete_sql_database_initial( self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2028,7 +2028,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_sql_database_request( resource_group_name=resource_group_name, @@ -2039,10 +2039,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2050,6 +2050,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + 
except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -2060,8 +2064,12 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_sql_database( @@ -2089,7 +2097,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_sql_database_initial( # type: ignore + raw_result = self._delete_sql_database_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -2099,6 +2107,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -2161,7 +2170,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -2175,7 +2183,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - 
deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2189,7 +2197,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods database_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2203,7 +2211,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2225,10 +2233,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2236,20 +2244,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, 
error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2370,10 +2380,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2397,7 +2408,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_sql_database_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2410,7 +2421,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", 
None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_sql_database_to_autoscale_request( resource_group_name=resource_group_name, @@ -2421,10 +2432,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2432,20 +2443,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2488,10 +2501,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + 
deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2515,7 +2529,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_sql_database_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2528,7 +2542,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_sql_database_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -2539,10 +2553,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2550,20 +2564,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, 
error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2606,10 +2622,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -2675,7 +2692,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -2691,7 +2707,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -2763,7 +2778,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) 
_stream = False @@ -2777,7 +2791,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("SqlContainerGetResults", pipeline_response) + deserialized = self._deserialize("SqlContainerGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -2792,7 +2806,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods container_name: str, create_update_sql_container_parameters: Union[_models.SqlContainerCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.SqlContainerGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -2806,7 +2820,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.SqlContainerGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -2829,10 +2843,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -2840,20 +2854,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = 
pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("SqlContainerGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -2983,10 +2999,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("SqlContainerGetResults", pipeline_response) + deserialized = self._deserialize("SqlContainerGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -3008,9 +3025,9 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_sql_container_initial( # pylint: disable=inconsistent-return-statements + def _delete_sql_container_initial( self, resource_group_name: str, account_name: str, database_name: str, container_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: 
ResourceNotFoundError, @@ -3023,7 +3040,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_sql_container_request( resource_group_name=resource_group_name, @@ -3035,10 +3052,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -3046,6 +3063,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -3056,8 +3077,12 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_sql_container( @@ -3087,7 +3112,7 @@ class SqlResourcesOperations: # pylint: 
disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_sql_container_initial( # type: ignore + raw_result = self._delete_sql_container_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -3098,6 +3123,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -3163,7 +3189,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -3177,7 +3202,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -3192,7 +3217,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods container_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -3206,7 +3231,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = 
kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -3229,10 +3254,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -3240,20 +3265,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -3384,10 +3411,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) 
+ raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -3411,7 +3439,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_sql_container_to_autoscale_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, container_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -3424,7 +3452,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_sql_container_to_autoscale_request( resource_group_name=resource_group_name, @@ -3436,10 +3464,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -3447,20 +3475,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code 
not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -3506,10 +3536,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -3533,7 +3564,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods def _migrate_sql_container_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, database_name: str, container_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -3546,7 +3577,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods 
_params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_sql_container_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -3558,10 +3589,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -3569,20 +3600,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -3628,10 +3661,11 @@ class SqlResourcesOperations: # pylint: 
disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -3697,7 +3731,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -3713,7 +3746,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -3790,7 +3822,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -3804,7 +3835,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ClientEncryptionKeyGetResults", pipeline_response) + deserialized = self._deserialize("ClientEncryptionKeyGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -3821,7 +3852,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _models.ClientEncryptionKeyCreateUpdateParameters, IO[bytes] ], **kwargs: Any - ) -> 
Optional[_models.ClientEncryptionKeyGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -3835,7 +3866,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ClientEncryptionKeyGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -3860,10 +3891,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -3871,20 +3902,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ClientEncryptionKeyGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", 
response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -4023,10 +4056,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ClientEncryptionKeyGetResults", pipeline_response) + deserialized = self._deserialize("ClientEncryptionKeyGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -4095,7 +4129,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -4111,7 +4144,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -4192,7 +4224,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -4206,7 +4237,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("SqlStoredProcedureGetResults", pipeline_response) + deserialized = self._deserialize("SqlStoredProcedureGetResults", pipeline_response.http_response) if cls: 
return cls(pipeline_response, deserialized, {}) # type: ignore @@ -4224,7 +4255,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _models.SqlStoredProcedureCreateUpdateParameters, IO[bytes] ], **kwargs: Any - ) -> Optional[_models.SqlStoredProcedureGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -4238,7 +4269,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.SqlStoredProcedureGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -4264,10 +4295,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -4275,20 +4306,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("SqlStoredProcedureGetResults", 
pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -4431,10 +4464,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("SqlStoredProcedureGetResults", pipeline_response) + deserialized = self._deserialize("SqlStoredProcedureGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -4456,7 +4490,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_sql_stored_procedure_initial( # pylint: disable=inconsistent-return-statements + def _delete_sql_stored_procedure_initial( self, resource_group_name: str, account_name: str, @@ -4464,7 +4498,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods container_name: str, stored_procedure_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -4477,7 +4511,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", 
None) _request = build_delete_sql_stored_procedure_request( resource_group_name=resource_group_name, @@ -4490,10 +4524,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -4501,6 +4535,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -4511,8 +4549,12 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_sql_stored_procedure( @@ -4550,7 +4592,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_sql_stored_procedure_initial( # type: ignore + raw_result = self._delete_sql_stored_procedure_initial( resource_group_name=resource_group_name, 
account_name=account_name, database_name=database_name, @@ -4562,6 +4604,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -4631,7 +4674,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -4647,7 +4689,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -4728,7 +4769,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -4742,7 +4782,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("SqlUserDefinedFunctionGetResults", pipeline_response) + deserialized = self._deserialize("SqlUserDefinedFunctionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -4760,7 +4800,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _models.SqlUserDefinedFunctionCreateUpdateParameters, IO[bytes] ], **kwargs: Any - ) -> Optional[_models.SqlUserDefinedFunctionGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 
401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -4774,7 +4814,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.SqlUserDefinedFunctionGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -4800,10 +4840,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -4811,20 +4851,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("SqlUserDefinedFunctionGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return 
cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -4970,10 +5012,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("SqlUserDefinedFunctionGetResults", pipeline_response) + deserialized = self._deserialize("SqlUserDefinedFunctionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -4995,7 +5038,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_sql_user_defined_function_initial( # pylint: disable=inconsistent-return-statements,name-too-long + def _delete_sql_user_defined_function_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, @@ -5003,7 +5046,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods container_name: str, user_defined_function_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -5016,7 +5059,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_sql_user_defined_function_request( resource_group_name=resource_group_name, @@ -5029,10 +5072,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = 
self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -5040,6 +5083,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -5050,8 +5097,12 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_sql_user_defined_function( @@ -5089,7 +5140,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_sql_user_defined_function_initial( # type: ignore + raw_result = self._delete_sql_user_defined_function_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -5101,6 +5152,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def 
get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -5169,7 +5221,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -5185,7 +5236,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -5266,7 +5316,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -5280,7 +5329,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("SqlTriggerGetResults", pipeline_response) + deserialized = self._deserialize("SqlTriggerGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -5296,7 +5345,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods trigger_name: str, create_update_sql_trigger_parameters: Union[_models.SqlTriggerCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.SqlTriggerGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -5310,7 +5359,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", 
self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.SqlTriggerGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -5334,10 +5383,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -5345,20 +5394,22 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("SqlTriggerGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -5498,10 +5549,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore 
kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("SqlTriggerGetResults", pipeline_response) + deserialized = self._deserialize("SqlTriggerGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -5523,7 +5575,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_sql_trigger_initial( # pylint: disable=inconsistent-return-statements + def _delete_sql_trigger_initial( self, resource_group_name: str, account_name: str, @@ -5531,7 +5583,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods container_name: str, trigger_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -5544,7 +5596,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_sql_trigger_request( resource_group_name=resource_group_name, @@ -5557,10 +5609,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -5568,6 +5620,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = 
pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -5578,8 +5634,12 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_sql_trigger( @@ -5617,7 +5677,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_sql_trigger_initial( # type: ignore + raw_result = self._delete_sql_trigger_initial( resource_group_name=resource_group_name, account_name=account_name, database_name=database_name, @@ -5629,6 +5689,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -5690,7 +5751,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -5704,7 +5764,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods 
map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("SqlRoleDefinitionGetResults", pipeline_response) + deserialized = self._deserialize("SqlRoleDefinitionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -5718,7 +5778,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods account_name: str, create_update_sql_role_definition_parameters: Union[_models.SqlRoleDefinitionCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.SqlRoleDefinitionGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -5732,7 +5792,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.SqlRoleDefinitionGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -5756,10 +5816,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -5767,12 +5827,14 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # 
Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("SqlRoleDefinitionGetResults", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -5894,10 +5956,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("SqlRoleDefinitionGetResults", pipeline_response) + deserialized = self._deserialize("SqlRoleDefinitionGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -5919,9 +5982,9 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_sql_role_definition_initial( # pylint: disable=inconsistent-return-statements + def _delete_sql_role_definition_initial( self, role_definition_id: str, resource_group_name: str, account_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -5934,7 +5997,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = 
kwargs.pop("cls", None) _request = build_delete_sql_role_definition_request( role_definition_id=role_definition_id, @@ -5945,10 +6008,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -5956,11 +6019,19 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_sql_role_definition( @@ -5988,7 +6059,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_sql_role_definition_initial( # type: ignore + raw_result = self._delete_sql_role_definition_initial( role_definition_id=role_definition_id, resource_group_name=resource_group_name, account_name=account_name, @@ -5998,6 +6069,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + 
raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -6060,7 +6132,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -6076,7 +6147,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -6145,7 +6215,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -6159,7 +6228,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("SqlRoleAssignmentGetResults", pipeline_response) + deserialized = self._deserialize("SqlRoleAssignmentGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -6173,7 +6242,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods account_name: str, create_update_sql_role_assignment_parameters: Union[_models.SqlRoleAssignmentCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.SqlRoleAssignmentGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -6187,7 +6256,7 @@ class SqlResourcesOperations: # pylint: 
disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.SqlRoleAssignmentGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -6211,10 +6280,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -6222,12 +6291,14 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("SqlRoleAssignmentGetResults", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -6349,10 +6420,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("SqlRoleAssignmentGetResults", pipeline_response) + deserialized = 
self._deserialize("SqlRoleAssignmentGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -6374,9 +6446,9 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_sql_role_assignment_initial( # pylint: disable=inconsistent-return-statements + def _delete_sql_role_assignment_initial( self, role_assignment_id: str, resource_group_name: str, account_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -6389,7 +6461,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_sql_role_assignment_request( role_assignment_id=role_assignment_id, @@ -6400,10 +6472,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -6411,11 +6483,19 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, 
response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, {}) # type: ignore + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_sql_role_assignment( @@ -6443,7 +6523,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_sql_role_assignment_initial( # type: ignore + raw_result = self._delete_sql_role_assignment_initial( role_assignment_id=role_assignment_id, resource_group_name=resource_group_name, account_name=account_name, @@ -6453,6 +6533,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -6515,7 +6596,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -6531,7 +6611,6 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -6568,7 +6647,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods container_name: str, location: Union[_models.ContinuousBackupRestoreLocation, IO[bytes]], **kwargs: Any - ) -> 
Optional[_models.BackupInformation]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -6582,7 +6661,7 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.BackupInformation]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -6605,10 +6684,10 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -6616,12 +6695,14 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -6748,10 +6829,11 @@ class SqlResourcesOperations: # pylint: disable=too-many-public-methods params=_params, **kwargs ) + 
raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = self._deserialize("BackupInformation", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_table_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_table_resources_operations.py index b2f03dd009a..1d546b87320 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_table_resources_operations.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/azure/mgmt/cosmosdb/operations/_table_resources_operations.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------- from io import IOBase import sys -from typing import Any, Callable, Dict, IO, Iterable, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, Callable, Dict, IO, Iterable, Iterator, Optional, Type, TypeVar, Union, cast, overload import urllib.parse from azure.core.exceptions import ( @@ -17,13 +17,14 @@ from azure.core.exceptions import ( ResourceExistsError, ResourceNotFoundError, ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, map_error, ) from azure.core.paging import ItemPaged from azure.core.pipeline import PipelineResponse -from azure.core.pipeline.transport import HttpResponse from azure.core.polling import LROPoller, NoPolling, PollingMethod -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from azure.core.tracing.decorator import distributed_trace from azure.core.utils import case_insensitive_dict from azure.mgmt.core.exceptions import ARMErrorFormat @@ -31,7 +32,6 @@ from azure.mgmt.core.polling.arm_polling import ARMPolling from .. 
import models as _models from .._serialization import Serializer -from .._vendor import _convert_request if sys.version_info >= (3, 9): from collections.abc import MutableMapping @@ -50,7 +50,7 @@ def build_list_tables_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -85,7 +85,7 @@ def build_get_table_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -121,7 +121,7 @@ def build_create_update_table_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -159,7 +159,7 @@ def build_delete_table_request( ) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) # Construct URL _url = kwargs.pop( "template_url", @@ -190,7 +190,7 @@ def build_get_table_throughput_request( _headers = 
case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -226,7 +226,7 @@ def build_update_table_throughput_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -265,7 +265,7 @@ def build_migrate_table_to_autoscale_request( _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -301,7 +301,7 @@ def build_migrate_table_to_manual_throughput_request( # pylint: disable=name-to _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) accept = _headers.pop("Accept", "application/json") # Construct URL @@ -337,7 +337,7 @@ def build_retrieve_continuous_backup_information_request( # pylint: disable=nam _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = 
case_insensitive_dict(kwargs.pop("params", {}) or {}) - api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-05-15")) + api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2024-08-15")) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) accept = _headers.pop("Accept", "application/json") @@ -429,7 +429,6 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) else: @@ -445,7 +444,6 @@ class TableResourcesOperations: _request = HttpRequest( "GET", urllib.parse.urljoin(next_link, _parsed_next_link.path), params=_next_request_params ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _request.method = "GET" return _request @@ -514,7 +512,6 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -528,7 +525,7 @@ class TableResourcesOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("TableGetResults", pipeline_response) + deserialized = self._deserialize("TableGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -542,7 +539,7 @@ class TableResourcesOperations: table_name: str, create_update_table_parameters: Union[_models.TableCreateUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.TableGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -556,7 +553,7 @@ class TableResourcesOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", 
self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.TableGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -578,10 +575,10 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -589,20 +586,22 @@ class TableResourcesOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("TableGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -721,10 +720,11 @@ class TableResourcesOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("TableGetResults", pipeline_response) + 
deserialized = self._deserialize("TableGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -746,9 +746,9 @@ class TableResourcesOperations: self._client, raw_result, get_long_running_output, polling_method # type: ignore ) - def _delete_table_initial( # pylint: disable=inconsistent-return-statements + def _delete_table_initial( self, resource_group_name: str, account_name: str, table_name: str, **kwargs: Any - ) -> None: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -761,7 +761,7 @@ class TableResourcesOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[None] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_delete_table_request( resource_group_name=resource_group_name, @@ -772,10 +772,10 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -783,6 +783,10 @@ class TableResourcesOperations: response = pipeline_response.http_response if response.status_code not in [202, 204]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) @@ -793,8 +797,12 @@ class TableResourcesOperations: ) response_headers["location"] = self._deserialize("str", 
response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: - return cls(pipeline_response, None, response_headers) # type: ignore + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore @distributed_trace def begin_delete_table( @@ -822,7 +830,7 @@ class TableResourcesOperations: lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) cont_token: Optional[str] = kwargs.pop("continuation_token", None) if cont_token is None: - raw_result = self._delete_table_initial( # type: ignore + raw_result = self._delete_table_initial( resource_group_name=resource_group_name, account_name=account_name, table_name=table_name, @@ -832,6 +840,7 @@ class TableResourcesOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): # pylint: disable=inconsistent-return-statements @@ -894,7 +903,6 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) _stream = False @@ -908,7 +916,7 @@ class TableResourcesOperations: map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -922,7 +930,7 @@ class TableResourcesOperations: table_name: str, update_throughput_parameters: Union[_models.ThroughputSettingsUpdateParameters, IO[bytes]], **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, 
Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -936,7 +944,7 @@ class TableResourcesOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -958,10 +966,10 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -969,20 +977,22 @@ class TableResourcesOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1103,10 +1113,11 @@ class 
TableResourcesOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1130,7 +1141,7 @@ class TableResourcesOperations: def _migrate_table_to_autoscale_initial( self, resource_group_name: str, account_name: str, table_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1143,7 +1154,7 @@ class TableResourcesOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) _request = build_migrate_table_to_autoscale_request( resource_group_name=resource_group_name, @@ -1154,10 +1165,10 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1165,20 +1176,22 @@ class TableResourcesOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass 
map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1221,10 +1234,11 @@ class TableResourcesOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized @@ -1248,7 +1262,7 @@ class TableResourcesOperations: def _migrate_table_to_manual_throughput_initial( # pylint: disable=name-too-long self, resource_group_name: str, account_name: str, table_name: str, **kwargs: Any - ) -> Optional[_models.ThroughputSettingsGetResults]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1261,7 +1275,7 @@ class TableResourcesOperations: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) - cls: ClsType[Optional[_models.ThroughputSettingsGetResults]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", 
None) _request = build_migrate_table_to_manual_throughput_request( resource_group_name=resource_group_name, @@ -1272,10 +1286,10 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1283,20 +1297,22 @@ class TableResourcesOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None response_headers = {} - if response.status_code == 200: - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) - if response.status_code == 202: response_headers["azure-AsyncOperation"] = self._deserialize( "str", response.headers.get("azure-AsyncOperation") ) response_headers["location"] = self._deserialize("str", response.headers.get("location")) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) + if cls: return cls(pipeline_response, deserialized, response_headers) # type: ignore @@ -1339,10 +1355,11 @@ class TableResourcesOperations: params=_params, **kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response) + deserialized = self._deserialize("ThroughputSettingsGetResults", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return 
deserialized @@ -1371,7 +1388,7 @@ class TableResourcesOperations: table_name: str, location: Union[_models.ContinuousBackupRestoreLocation, IO[bytes]], **kwargs: Any - ) -> Optional[_models.BackupInformation]: + ) -> Iterator[bytes]: error_map: MutableMapping[int, Type[HttpResponseError]] = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, @@ -1385,7 +1402,7 @@ class TableResourcesOperations: api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version)) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[Optional[_models.BackupInformation]] = kwargs.pop("cls", None) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) content_type = content_type or "application/json" _json = None @@ -1407,10 +1424,10 @@ class TableResourcesOperations: headers=_headers, params=_params, ) - _request = _convert_request(_request) _request.url = self._client.format_url(_request.url) - _stream = False + _decompress = kwargs.pop("decompress", True) + _stream = True pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access _request, stream=_stream, **kwargs ) @@ -1418,12 +1435,14 @@ class TableResourcesOperations: response = pipeline_response.http_response if response.status_code not in [200, 202]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass map_error(status_code=response.status_code, response=response, error_map=error_map) raise HttpResponseError(response=response, error_format=ARMErrorFormat) - deserialized = None - if response.status_code == 200: - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = response.stream_download(self._client._pipeline, decompress=_decompress) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore @@ -1540,10 +1559,11 @@ class TableResourcesOperations: params=_params, 
**kwargs ) + raw_result.http_response.read() # type: ignore kwargs.pop("error_map", None) def get_long_running_output(pipeline_response): - deserialized = self._deserialize("BackupInformation", pipeline_response) + deserialized = self._deserialize("BackupInformation", pipeline_response.http_response) if cls: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/dev_requirements.txt b/sdk/cosmos/azure-mgmt-cosmosdb/dev_requirements.txt index f6457a93d5e..6195bb36ac8 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/dev_requirements.txt +++ b/sdk/cosmos/azure-mgmt-cosmosdb/dev_requirements.txt @@ -1 +1,2 @@ --e ../../../tools/azure-sdk-tools \ No newline at end of file +-e ../../../tools/azure-sdk-tools +aiohttp \ No newline at end of file diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_create_update.py index 200bb9f4381..d0d47282c41 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -45,6 +43,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraKeyspaceCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraKeyspaceCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_delete.py index d1b2fbe307f..34a27716c77 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraKeyspaceDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraKeyspaceDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_get.py index 27dcb62f76e..b681200555a 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraKeyspaceGet.json 
+# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraKeyspaceGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_list.py index a1326d46e23..87facd9d5a8 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraKeyspaceList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraKeyspaceList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_migrate_to_autoscale.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_migrate_to_autoscale.py index 837cfb5b5a9..d2e66c3d08f 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_migrate_to_autoscale.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_migrate_to_autoscale.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraKeyspaceMigrateToAutoscale.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraKeyspaceMigrateToAutoscale.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_migrate_to_manual_throughput.py 
b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_migrate_to_manual_throughput.py index 78ae20785ae..5b3d3d03817 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_migrate_to_manual_throughput.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_migrate_to_manual_throughput.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraKeyspaceMigrateToManualThroughput.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraKeyspaceMigrateToManualThroughput.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_throughput_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_throughput_get.py index 530823fa98e..8ab3ecb8a6a 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_throughput_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_throughput_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraKeyspaceThroughputGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraKeyspaceThroughputGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_throughput_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_throughput_update.py index 5f33145b739..3f0dafdde91 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_throughput_update.py +++ 
b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_keyspace_throughput_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -45,6 +43,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraKeyspaceThroughputUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraKeyspaceThroughputUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_create_update.py index b42cac4704d..35dfa23c1af 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -57,6 +55,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraTableCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraTableCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_delete.py index 6ca5d7ec1dc..32b49df3d65 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_delete.py @@ -38,6 +38,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraTableDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraTableDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_get.py index 32d9abb3d92..4b275f44a62 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_get.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraTableGet.json +# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraTableGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_list.py index 6ef7d35dd92..bdc53562a0c 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_list.py @@ -39,6 +39,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraTableList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraTableList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_migrate_to_autoscale.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_migrate_to_autoscale.py index 8226c73d974..1479bdfcde8 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_migrate_to_autoscale.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_migrate_to_autoscale.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraTableMigrateToAutoscale.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraTableMigrateToAutoscale.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_migrate_to_manual_throughput.py 
b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_migrate_to_manual_throughput.py index 86834b8fccf..99f386509bd 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_migrate_to_manual_throughput.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_migrate_to_manual_throughput.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraTableMigrateToManualThroughput.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraTableMigrateToManualThroughput.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_throughput_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_throughput_get.py index fa3df823e4e..11c17da25f8 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_throughput_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_throughput_get.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraTableThroughputGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraTableThroughputGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_throughput_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_throughput_update.py index e3d42c2f29c..cdc2f364204 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_throughput_update.py +++ 
b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_cassandra_table_throughput_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -46,6 +44,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCassandraTableThroughputUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCassandraTableThroughputUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_get_metric_definitions.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_get_metric_definitions.py index 5e5a1d7586f..8391dd73481 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_get_metric_definitions.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_get_metric_definitions.py @@ -40,6 +40,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCollectionGetMetricDefinitions.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCollectionGetMetricDefinitions.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_get_metrics.py similarity index 71% rename from sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_create_update.py rename to 
sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_get_metrics.py index 5099f6faee9..00e1592e2da 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_get_metrics.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -17,7 +15,7 @@ from azure.mgmt.cosmosdb import CosmosDBManagementClient pip install azure-identity pip install azure-mgmt-cosmosdb # USAGE - python cosmos_db_mongo_db_database_create_update.py + python cosmos_db_collection_get_metrics.py Before run the sample, please set the values of the client ID, tenant ID and client secret of the AAD application as environment variables: AZURE_CLIENT_ID, AZURE_TENANT_ID, @@ -32,19 +30,17 @@ def main(): subscription_id="subid", ) - response = client.mongo_db_resources.begin_create_update_mongo_db_database( + response = client.collection.list_metrics( resource_group_name="rg1", account_name="ddb1", - database_name="databaseName", - create_update_mongo_db_database_parameters={ - "location": "West US", - "properties": {"options": {}, "resource": {"id": "databaseName"}}, - "tags": {}, - }, - ).result() - print(response) + database_rid="databaseRid", + collection_rid="collectionRid", + filter="$filter=(name.value eq 'Total Requests') and timeGrain eq duration'PT5M' and startTime eq '2017-11-19T23:53:55.2780000Z' and endTime eq '2017-11-20T00:13:55.2780000Z", + ) + for item in response: + print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBDatabaseCreateUpdate.json +# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCollectionGetMetrics.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_get_usages.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_get_usages.py index c08b97e0121..1040c774cc3 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_get_usages.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_get_usages.py @@ -40,6 +40,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCollectionGetUsages.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCollectionGetUsages.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_partition_get_metrics.py similarity index 58% rename from sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_create_update.py rename to sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_partition_get_metrics.py index 526a629ef7c..930d99cdbbc 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_partition_get_metrics.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -17,7 +15,7 @@ from azure.mgmt.cosmosdb import CosmosDBManagementClient pip install azure-identity pip install azure-mgmt-cosmosdb # USAGE - python cosmos_db_mongo_db_collection_create_update.py + python cosmos_db_collection_partition_get_metrics.py Before run the sample, please set the values of the client ID, tenant ID and client secret of the AAD application as environment variables: AZURE_CLIENT_ID, AZURE_TENANT_ID, @@ -32,30 +30,17 @@ def main(): subscription_id="subid", ) - response = client.mongo_db_resources.begin_create_update_mongo_db_collection( + response = client.collection_partition.list_metrics( resource_group_name="rg1", account_name="ddb1", - database_name="databaseName", - collection_name="collectionName", - create_update_mongo_db_collection_parameters={ - "location": "West US", - "properties": { - "options": {}, - "resource": { - "id": "collectionName", - "indexes": [ - {"key": {"keys": ["_ts"]}, "options": {"expireAfterSeconds": 100, "unique": True}}, - {"key": {"keys": ["_id"]}}, - ], - "shardKey": {"testKey": "Hash"}, - }, - }, - "tags": {}, - }, - ).result() - print(response) + database_rid="databaseRid", + collection_rid="collectionRid", + filter="$filter=(name.value eq 'Max RUs Per Second') and timeGrain eq duration'PT1M' and startTime eq '2017-11-19T23:53:55.2780000Z' and endTime eq '2017-11-20T23:58:55.2780000Z", + ) + for item in response: + print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBCollectionCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCollectionPartitionGetMetrics.json if __name__ == "__main__": main() diff --git 
a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_partition_get_usages.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_partition_get_usages.py index a412c4f7288..36916b71410 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_partition_get_usages.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_partition_get_usages.py @@ -40,6 +40,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBCollectionPartitionGetUsages.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCollectionPartitionGetUsages.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_partition_region_get_metrics.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_partition_region_get_metrics.py new file mode 100644 index 00000000000..926a839123b --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_collection_partition_region_get_metrics.py @@ -0,0 +1,47 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- + +from azure.identity import DefaultAzureCredential + +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +""" +# PREREQUISITES + pip install azure-identity + pip install azure-mgmt-cosmosdb +# USAGE + python cosmos_db_collection_partition_region_get_metrics.py + + Before run the sample, please set the values of the client ID, tenant ID and client secret + of the AAD application as environment variables: AZURE_CLIENT_ID, AZURE_TENANT_ID, + AZURE_CLIENT_SECRET. For more info about how to get the value, please see: + https://docs.microsoft.com/azure/active-directory/develop/howto-create-service-principal-portal +""" + + +def main(): + client = CosmosDBManagementClient( + credential=DefaultAzureCredential(), + subscription_id="subid", + ) + + response = client.collection_partition_region.list_metrics( + resource_group_name="rg1", + account_name="ddb1", + region="North Europe", + database_rid="databaseRid", + collection_rid="collectionRid", + filter="$filter=(name.value eq 'Max RUs Per Second') and timeGrain eq duration'PT1M' and startTime eq '2017-11-19T23:53:55.2780000Z' and endTime eq '2017-11-20T23:58:55.2780000Z", + ) + for item in response: + print(item) + + +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBCollectionPartitionRegionGetMetrics.json +if __name__ == "__main__": + main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_data_transfer_service_create.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_data_transfer_service_create.py index c4607653577..541cd47f916 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_data_transfer_service_create.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_data_transfer_service_create.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -43,6 +41,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDataTransferServiceCreate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDataTransferServiceCreate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_data_transfer_service_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_data_transfer_service_delete.py index a644c65c9a8..738ed7730fe 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_data_transfer_service_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_data_transfer_service_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDataTransferServiceDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDataTransferServiceDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_data_transfer_service_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_data_transfer_service_get.py index e9ebfd9c2fb..16f9f942ecb 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_data_transfer_service_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_data_transfer_service_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDataTransferServiceGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDataTransferServiceGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_check_name_exists.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_check_name_exists.py index 2b01f481781..d0dbb1be595 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_check_name_exists.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_check_name_exists.py @@ -36,6 +36,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountCheckNameExists.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountCheckNameExists.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_create_max.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_create_max.py index 8c3b56db4db..e0167b8456e 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_create_max.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_create_max.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -94,6 +92,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountCreateMax.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountCreateMax.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_create_min.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_create_min.py index 84b0f44c2f3..671f089e0ce 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_create_min.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_create_min.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -47,6 +45,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountCreateMin.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountCreateMin.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_delete.py index 1988297483b..9fcf29ff833 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_delete.py @@ -36,6 +36,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_failover_priority_change.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_failover_priority_change.py index dc8c9a23ced..b3bad8dcfae 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_failover_priority_change.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_failover_priority_change.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -44,6 +42,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountFailoverPriorityChange.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountFailoverPriorityChange.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get.py index 3127ec5b609..e5b2f52eacd 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get.py @@ -37,6 +37,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get_metric_definitions.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get_metric_definitions.py index beb62cf9b8e..5e21e104709 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get_metric_definitions.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get_metric_definitions.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountGetMetricDefinitions.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountGetMetricDefinitions.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get_metrics.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get_metrics.py new file mode 100644 index 00000000000..9de4c6b601c --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get_metrics.py @@ -0,0 +1,44 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- + +from azure.identity import DefaultAzureCredential + +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +""" +# PREREQUISITES + pip install azure-identity + pip install azure-mgmt-cosmosdb +# USAGE + python cosmos_db_database_account_get_metrics.py + + Before run the sample, please set the values of the client ID, tenant ID and client secret + of the AAD application as environment variables: AZURE_CLIENT_ID, AZURE_TENANT_ID, + AZURE_CLIENT_SECRET. 
For more info about how to get the value, please see: + https://docs.microsoft.com/azure/active-directory/develop/howto-create-service-principal-portal +""" + + +def main(): + client = CosmosDBManagementClient( + credential=DefaultAzureCredential(), + subscription_id="subid", + ) + + response = client.database_accounts.list_metrics( + resource_group_name="rg1", + account_name="ddb1", + filter="$filter=(name.value eq 'Total Requests') and timeGrain eq duration'PT5M' and startTime eq '2017-11-19T23:53:55.2780000Z' and endTime eq '2017-11-20T00:13:55.2780000Z", + ) + for item in response: + print(item) + + +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountGetMetrics.json +if __name__ == "__main__": + main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get_usages.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get_usages.py index e45a5707ee0..9285dade5bd 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get_usages.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_get_usages.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountGetUsages.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountGetUsages.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list.py index a2a98a8af1b..8abc0de60db 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list.py +++ 
b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list.py @@ -35,6 +35,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_by_resource_group.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_by_resource_group.py index f27a05d7d27..86dc9750f5e 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_by_resource_group.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_by_resource_group.py @@ -37,6 +37,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountListByResourceGroup.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountListByResourceGroup.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_connection_strings.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_connection_strings.py index 798f9710bfe..cf198f855f4 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_connection_strings.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_connection_strings.py @@ -37,6 +37,6 @@ def main(): print(response) -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountListConnectionStrings.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountListConnectionStrings.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_connection_strings_mongo.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_connection_strings_mongo.py index 71d2bed8390..8a8f1c91297 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_connection_strings_mongo.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_connection_strings_mongo.py @@ -37,6 +37,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountListConnectionStringsMongo.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountListConnectionStringsMongo.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_keys.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_keys.py index 30d553720a0..53bca77168e 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_keys.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_keys.py @@ -37,6 +37,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountListKeys.json +# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountListKeys.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_read_only_keys.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_read_only_keys.py index 4088f76c079..c4791242cdc 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_read_only_keys.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_list_read_only_keys.py @@ -30,13 +30,13 @@ def main(): subscription_id="subid", ) - response = client.database_accounts.get_read_only_keys( + response = client.database_accounts.list_read_only_keys( resource_group_name="rg1", account_name="ddb1", ) print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountListReadOnlyKeys.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountListReadOnlyKeys.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_offline_region.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_offline_region.py index f7600c1f020..8a98a18dc5a 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_offline_region.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_offline_region.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -35,10 +33,10 @@ def main(): client.database_accounts.begin_offline_region( resource_group_name="rg1", account_name="ddb1", - region_parameter_for_offline=[{"region": "North Europe"}], + region_parameter_for_offline={"region": "North Europe"}, ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountOfflineRegion.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountOfflineRegion.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_online_region.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_online_region.py index 7e23e6c9005..085790fd4ab 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_online_region.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_online_region.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -35,10 +33,10 @@ def main(): client.database_accounts.begin_online_region( resource_group_name="rg1", account_name="ddb1", - region_parameter_for_online=[{"region": "North Europe"}], + region_parameter_for_online={"region": "North Europe"}, ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountOnlineRegion.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountOnlineRegion.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_patch.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_patch.py index 3655914fdde..61a83a74408 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_patch.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_patch.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -84,6 +82,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountPatch.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountPatch.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_regenerate_key.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_regenerate_key.py index c796a80c6aa..34924249be3 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_regenerate_key.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_regenerate_key.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -39,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseAccountRegenerateKey.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountRegenerateKey.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_region_get_metrics.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_region_get_metrics.py new file mode 100644 index 00000000000..054aecb0ec5 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_account_region_get_metrics.py @@ -0,0 +1,45 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- + +from azure.identity import DefaultAzureCredential + +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +""" +# PREREQUISITES + pip install azure-identity + pip install azure-mgmt-cosmosdb +# USAGE + python cosmos_db_database_account_region_get_metrics.py + + Before run the sample, please set the values of the client ID, tenant ID and client secret + of the AAD application as environment variables: AZURE_CLIENT_ID, AZURE_TENANT_ID, + AZURE_CLIENT_SECRET. 
For more info about how to get the value, please see: + https://docs.microsoft.com/azure/active-directory/develop/howto-create-service-principal-portal +""" + + +def main(): + client = CosmosDBManagementClient( + credential=DefaultAzureCredential(), + subscription_id="subid", + ) + + response = client.database_account_region.list_metrics( + resource_group_name="rg1", + account_name="ddb1", + region="North Europe", + filter="$filter=(name.value eq 'Total Requests') and timeGrain eq duration'PT5M' and startTime eq '2017-11-19T23:53:55.2780000Z' and endTime eq '2017-11-20T00:13:55.2780000Z", + ) + for item in response: + print(item) + + +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseAccountRegionGetMetrics.json +if __name__ == "__main__": + main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_get_metric_definitions.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_get_metric_definitions.py index 395e9a14bf6..f3c9a3497e7 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_get_metric_definitions.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_get_metric_definitions.py @@ -39,6 +39,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseGetMetricDefinitions.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseGetMetricDefinitions.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_get_metrics.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_get_metrics.py new file mode 100644 index 00000000000..e1070ed8624 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_get_metrics.py @@ 
-0,0 +1,45 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- + +from azure.identity import DefaultAzureCredential + +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +""" +# PREREQUISITES + pip install azure-identity + pip install azure-mgmt-cosmosdb +# USAGE + python cosmos_db_database_get_metrics.py + + Before run the sample, please set the values of the client ID, tenant ID and client secret + of the AAD application as environment variables: AZURE_CLIENT_ID, AZURE_TENANT_ID, + AZURE_CLIENT_SECRET. For more info about how to get the value, please see: + https://docs.microsoft.com/azure/active-directory/develop/howto-create-service-principal-portal +""" + + +def main(): + client = CosmosDBManagementClient( + credential=DefaultAzureCredential(), + subscription_id="subid", + ) + + response = client.database.list_metrics( + resource_group_name="rg1", + account_name="ddb1", + database_rid="rid", + filter="$filter=(name.value eq 'Total Requests') and timeGrain eq duration'PT5M' and startTime eq '2017-11-19T23:53:55.2780000Z' and endTime eq '2017-11-20T00:13:55.2780000Z", + ) + for item in response: + print(item) + + +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseGetMetrics.json +if __name__ == "__main__": + main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_get_usages.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_get_usages.py index 221fd9205af..fd159685e89 100644 --- 
a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_get_usages.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_database_get_usages.py @@ -39,6 +39,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBDatabaseGetUsages.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBDatabaseGetUsages.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_graph_api_compute_service_create.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_graph_api_compute_service_create.py index 41e56163a93..4af060dbe4e 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_graph_api_compute_service_create.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_graph_api_compute_service_create.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -43,6 +41,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGraphAPIComputeServiceCreate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGraphAPIComputeServiceCreate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_graph_api_compute_service_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_graph_api_compute_service_delete.py index 62544413c6a..c709abf4f30 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_graph_api_compute_service_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_graph_api_compute_service_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGraphAPIComputeServiceDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGraphAPIComputeServiceDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_graph_api_compute_service_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_graph_api_compute_service_get.py index 0d1d47bfcc7..f873dace3e8 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_graph_api_compute_service_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_graph_api_compute_service_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGraphAPIComputeServiceGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGraphAPIComputeServiceGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_create_update.py index ea2676712fa..b06f7a36869 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -45,6 +43,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinDatabaseCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinDatabaseCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_delete.py index fbc9381aae1..76905102699 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinDatabaseDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinDatabaseDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_get.py index 53d2b64ec2a..6241f0a7ab6 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinDatabaseGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinDatabaseGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_list.py index 144d200b653..a7dab25ca15 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinDatabaseList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinDatabaseList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_migrate_to_autoscale.py 
b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_migrate_to_autoscale.py index ff03dbe0b47..2b2e6bac60b 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_migrate_to_autoscale.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_migrate_to_autoscale.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinDatabaseMigrateToAutoscale.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinDatabaseMigrateToAutoscale.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_migrate_to_manual_throughput.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_migrate_to_manual_throughput.py index ed7e103493e..3eb8453ccf3 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_migrate_to_manual_throughput.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_migrate_to_manual_throughput.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinDatabaseMigrateToManualThroughput.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinDatabaseMigrateToManualThroughput.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_throughput_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_throughput_get.py index 4ec93bce498..4ff0e9d3b49 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_throughput_get.py 
+++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_throughput_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinDatabaseThroughputGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinDatabaseThroughputGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_throughput_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_throughput_update.py index 5da2909db85..c081e017b14 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_throughput_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_database_throughput_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -45,6 +43,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinDatabaseThroughputUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinDatabaseThroughputUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_backup_information.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_backup_information.py index 29b521f08b4..20c0f5d633f 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_backup_information.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_backup_information.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -42,6 +40,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinGraphBackupInformation.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinGraphBackupInformation.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_create_update.py index 5e9916aed38..e4fbdaf49e9 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -69,6 +67,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinGraphCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinGraphCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_delete.py index 5e12ad408e2..ea79e9174a2 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_delete.py @@ -38,6 +38,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinGraphDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinGraphDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_get.py index 87194001da7..2cc72736436 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_get.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinGraphGet.json +# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinGraphGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_list.py index 761462181ec..15b10d9102a 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_list.py @@ -39,6 +39,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinGraphList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinGraphList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_migrate_to_autoscale.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_migrate_to_autoscale.py index cf180a2bd40..9b7bbca6fc3 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_migrate_to_autoscale.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_migrate_to_autoscale.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinGraphMigrateToAutoscale.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinGraphMigrateToAutoscale.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_migrate_to_manual_throughput.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_migrate_to_manual_throughput.py index 
1e3c2738abf..665105bae45 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_migrate_to_manual_throughput.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_migrate_to_manual_throughput.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinGraphMigrateToManualThroughput.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinGraphMigrateToManualThroughput.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_throughput_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_throughput_get.py index 3b7bde78627..29d842fcf38 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_throughput_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_throughput_get.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinGraphThroughputGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinGraphThroughputGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_throughput_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_throughput_update.py index fab75e60bcd..21a6e441182 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_throughput_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_gremlin_graph_throughput_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is 
regenerated. # -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -46,6 +44,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBGremlinGraphThroughputUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBGremlinGraphThroughputUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_location_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_location_get.py index 67ef4c971b2..4af6563dead 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_location_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_location_get.py @@ -36,6 +36,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBLocationGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBLocationGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_location_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_location_list.py index 2a3b121b177..433662fcc1e 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_location_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_location_list.py @@ -35,6 +35,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBLocationList.json +# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBLocationList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_create.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_create.py index 67cea973fc5..855788c0630 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_create.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_create.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -62,6 +60,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraClusterCreate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraClusterCreate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_deallocate.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_deallocate.py index 5f87388ff5e..b4bfc71b998 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_deallocate.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_deallocate.py @@ -36,6 +36,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraClusterDeallocate.json +# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraClusterDeallocate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_delete.py index 4083942ca49..43227f8ce07 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_delete.py @@ -36,6 +36,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraClusterDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraClusterDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_get.py index 26482ecd7c8..0c998e2e284 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_get.py @@ -37,6 +37,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraClusterGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraClusterGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_list_by_resource_group.py 
b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_list_by_resource_group.py index daa757680df..70b7d264f02 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_list_by_resource_group.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_list_by_resource_group.py @@ -37,6 +37,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraClusterListByResourceGroup.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraClusterListByResourceGroup.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_list_by_subscription.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_list_by_subscription.py index 99c9e186ff9..631a5553038 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_list_by_subscription.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_list_by_subscription.py @@ -35,6 +35,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraClusterListBySubscription.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraClusterListBySubscription.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_patch.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_patch.py index b70d6d04d6f..d62c6bb8ccb 100644 --- 
a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_patch.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_patch.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -54,6 +52,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraClusterPatch.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraClusterPatch.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_start.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_start.py index 73f5c146e8f..f4d08283e02 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_start.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_cluster_start.py @@ -36,6 +36,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraClusterStart.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraClusterStart.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_command.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_command.py index edfe5a291bd..fe721c134c0 100644 --- 
a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_command.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_command.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -40,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraCommand.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraCommand.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_create.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_create.py index b20aba0acf4..0c5e57af5ed 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_create.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_create.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -48,6 +46,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraDataCenterCreate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraDataCenterCreate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_delete.py index 0ea59881670..bccb96b5ddb 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraDataCenterDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraDataCenterDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_get.py index cb9dd2be0a6..5e629861c5a 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraDataCenterGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraDataCenterGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_list.py index e846da1ba0a..dbc7d061f3b 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraDataCenterList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraDataCenterList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_patch.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_patch.py index 3e4831ad22f..0e255a3b61f 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_patch.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_data_center_patch.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -48,6 +46,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraDataCenterPatch.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraDataCenterPatch.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_status.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_status.py index c8ae9ae79c8..ca4cf730369 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_status.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_managed_cassandra_status.py @@ -37,6 +37,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBManagedCassandraStatus.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBManagedCassandraStatus.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_materialized_views_builder_service_create.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_materialized_views_builder_service_create.py index 1d983ba130f..a596b8180ff 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_materialized_views_builder_service_create.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_materialized_views_builder_service_create.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -43,6 +41,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMaterializedViewsBuilderServiceCreate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMaterializedViewsBuilderServiceCreate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_materialized_views_builder_service_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_materialized_views_builder_service_delete.py index 2781fd60955..0613a316664 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_materialized_views_builder_service_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_materialized_views_builder_service_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMaterializedViewsBuilderServiceDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMaterializedViewsBuilderServiceDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_materialized_views_builder_service_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_materialized_views_builder_service_get.py index 7aded464ab8..ef338c9c5bd 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_materialized_views_builder_service_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_materialized_views_builder_service_get.py @@ -38,6 +38,6 @@ def 
main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMaterializedViewsBuilderServiceGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMaterializedViewsBuilderServiceGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_backup_information.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_backup_information.py index e8611cf2b90..2191f441175 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_backup_information.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_backup_information.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -42,6 +40,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBCollectionBackupInformation.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBCollectionBackupInformation.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_delete.py index 5fb499bc226..e4940cbfc87 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_delete.py @@ -38,6 
+38,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBCollectionDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBCollectionDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_get.py index ddd3c0014eb..de76b14b90f 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_get.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBCollectionGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBCollectionGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_list.py index a760e7cef2d..1c08ddbc654 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_list.py @@ -39,6 +39,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBCollectionList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBCollectionList.json if __name__ == "__main__": main() diff --git 
a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_migrate_to_autoscale.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_migrate_to_autoscale.py index e9a101a9f65..194de6323a2 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_migrate_to_autoscale.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_migrate_to_autoscale.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBCollectionMigrateToAutoscale.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBCollectionMigrateToAutoscale.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_migrate_to_manual_throughput.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_migrate_to_manual_throughput.py index 8e5132be0be..96424baf73b 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_migrate_to_manual_throughput.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_migrate_to_manual_throughput.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBCollectionMigrateToManualThroughput.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBCollectionMigrateToManualThroughput.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_throughput_get.py 
b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_throughput_get.py index 056db73179d..72f4aa1f7ef 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_throughput_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_throughput_get.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBCollectionThroughputGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBCollectionThroughputGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_throughput_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_throughput_update.py index 2c0f0679bde..2805ffda18d 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_throughput_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_collection_throughput_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -46,6 +44,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBCollectionThroughputUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBCollectionThroughputUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_delete.py index 55c09b571ef..bdd15b31d2a 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBDatabaseDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBDatabaseDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_get.py index 737ab265a2f..f3c9556d562 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBDatabaseGet.json +# 
x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBDatabaseGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_list.py index deab6db0715..554466f628a 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBDatabaseList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBDatabaseList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_migrate_to_autoscale.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_migrate_to_autoscale.py index 54df5877835..288e48173fd 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_migrate_to_autoscale.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_migrate_to_autoscale.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBDatabaseMigrateToAutoscale.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBDatabaseMigrateToAutoscale.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_migrate_to_manual_throughput.py 
b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_migrate_to_manual_throughput.py index f8d3e5645b5..738f84656a1 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_migrate_to_manual_throughput.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_migrate_to_manual_throughput.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBDatabaseMigrateToManualThroughput.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBDatabaseMigrateToManualThroughput.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_throughput_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_throughput_get.py index 11c6f7b35de..f152859261d 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_throughput_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_throughput_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBDatabaseThroughputGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBDatabaseThroughputGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_throughput_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_throughput_update.py index ce2bf8d04a5..a48a4ef3aae 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_throughput_update.py +++ 
b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_database_throughput_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -45,6 +43,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBDatabaseThroughputUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBDatabaseThroughputUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_create_update.py index 22dc9a15485..5a49c22f22f 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -48,6 +46,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBRoleDefinitionCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBRoleDefinitionCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_delete.py index 5aac813279a..0edb4caf7d6 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBRoleDefinitionDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBRoleDefinitionDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_get.py index e868661e7f3..d76706e7407 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBRoleDefinitionGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBRoleDefinitionGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_list.py index f4232483892..a75d0705951 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_role_definition_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBRoleDefinitionList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBRoleDefinitionList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_create_update.py index 717217702ca..3e5086d3b04 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -50,6 +48,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBUserDefinitionCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBUserDefinitionCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_delete.py index 96b3cb4163d..555fa61ee5c 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBUserDefinitionDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBUserDefinitionDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_get.py index 581476dd0f0..b480429a6b6 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBUserDefinitionGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBUserDefinitionGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_list.py index 9ae0a0ce833..b449d7457da 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_mongo_db_user_definition_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBMongoDBUserDefinitionList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBMongoDBUserDefinitionList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_delete.py index 3681e1363ac..b3735e47396 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_delete.py @@ -6,15 +6,10 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient -if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports - from .. 
import models as _models """ # PREREQUISITES pip install azure-identity @@ -42,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBNotebookWorkspaceDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBNotebookWorkspaceDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_get.py index f6f3e4f8612..7e1aff6a3df 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_get.py @@ -6,15 +6,10 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient -if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports - from .. 
import models as _models """ # PREREQUISITES pip install azure-identity @@ -43,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBNotebookWorkspaceGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBNotebookWorkspaceGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_list.py index d75ffa06f72..ed91445294c 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBNotebookWorkspaceList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBNotebookWorkspaceList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_list_connection_info.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_list_connection_info.py index db4ce3e246d..cc0bdd8a65b 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_list_connection_info.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_list_connection_info.py @@ -6,15 +6,10 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import TYPE_CHECKING, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient -if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports - from .. import models as _models """ # PREREQUISITES pip install azure-identity @@ -43,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBNotebookWorkspaceListConnectionInfo.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBNotebookWorkspaceListConnectionInfo.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_regenerate_auth_token.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_regenerate_auth_token.py index 0d2b1fb352a..43e76fd7f5a 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_regenerate_auth_token.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_regenerate_auth_token.py @@ -6,15 +6,10 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient -if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports - from .. 
import models as _models """ # PREREQUISITES pip install azure-identity @@ -42,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBNotebookWorkspaceRegenerateAuthToken.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBNotebookWorkspaceRegenerateAuthToken.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_start.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_start.py index 86f62496550..f2999bfcda4 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_start.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_notebook_workspace_start.py @@ -6,15 +6,10 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient -if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports - from .. 
import models as _models """ # PREREQUISITES pip install azure-identity @@ -42,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBNotebookWorkspaceStart.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBNotebookWorkspaceStart.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_operations_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_operations_list.py index 9e6b09a0312..716a3a781e7 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_operations_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_operations_list.py @@ -35,6 +35,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBOperationsList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBOperationsList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_percentile_get_metrics.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_percentile_get_metrics.py new file mode 100644 index 00000000000..869cc5071b1 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_percentile_get_metrics.py @@ -0,0 +1,44 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- + +from azure.identity import DefaultAzureCredential + +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +""" +# PREREQUISITES + pip install azure-identity + pip install azure-mgmt-cosmosdb +# USAGE + python cosmos_db_percentile_get_metrics.py + + Before run the sample, please set the values of the client ID, tenant ID and client secret + of the AAD application as environment variables: AZURE_CLIENT_ID, AZURE_TENANT_ID, + AZURE_CLIENT_SECRET. For more info about how to get the value, please see: + https://docs.microsoft.com/azure/active-directory/develop/howto-create-service-principal-portal +""" + + +def main(): + client = CosmosDBManagementClient( + credential=DefaultAzureCredential(), + subscription_id="subid", + ) + + response = client.percentile.list_metrics( + resource_group_name="rg1", + account_name="ddb1", + filter="$filter=(name.value eq 'Probabilistic Bounded Staleness') and timeGrain eq duration'PT5M' and startTime eq '2017-11-19T23:53:55.2780000Z' and endTime eq '2017-11-20T00:13:55.2780000Z", + ) + for item in response: + print(item) + + +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBPercentileGetMetrics.json +if __name__ == "__main__": + main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_percentile_source_target_get_metrics.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_percentile_source_target_get_metrics.py new file mode 100644 index 00000000000..58382d29937 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_percentile_source_target_get_metrics.py @@ -0,0 +1,46 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- + +from azure.identity import DefaultAzureCredential + +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +""" +# PREREQUISITES + pip install azure-identity + pip install azure-mgmt-cosmosdb +# USAGE + python cosmos_db_percentile_source_target_get_metrics.py + + Before run the sample, please set the values of the client ID, tenant ID and client secret + of the AAD application as environment variables: AZURE_CLIENT_ID, AZURE_TENANT_ID, + AZURE_CLIENT_SECRET. For more info about how to get the value, please see: + https://docs.microsoft.com/azure/active-directory/develop/howto-create-service-principal-portal +""" + + +def main(): + client = CosmosDBManagementClient( + credential=DefaultAzureCredential(), + subscription_id="subid", + ) + + response = client.percentile_source_target.list_metrics( + resource_group_name="rg1", + account_name="ddb1", + source_region="West Central US", + target_region="East US", + filter="$filter=(name.value eq 'Probabilistic Bounded Staleness') and timeGrain eq duration'PT5M' and startTime eq '2017-11-19T23:53:55.2780000Z' and endTime eq '2017-11-20T00:13:55.2780000Z", + ) + for item in response: + print(item) + + +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBPercentileSourceTargetGetMetrics.json +if __name__ == "__main__": + main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_percentile_target_get_metrics.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_percentile_target_get_metrics.py new file mode 100644 index 00000000000..7c8eee86694 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_percentile_target_get_metrics.py @@ -0,0 +1,45 @@ +# coding=utf-8 +# 
-------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- + +from azure.identity import DefaultAzureCredential + +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +""" +# PREREQUISITES + pip install azure-identity + pip install azure-mgmt-cosmosdb +# USAGE + python cosmos_db_percentile_target_get_metrics.py + + Before run the sample, please set the values of the client ID, tenant ID and client secret + of the AAD application as environment variables: AZURE_CLIENT_ID, AZURE_TENANT_ID, + AZURE_CLIENT_SECRET. For more info about how to get the value, please see: + https://docs.microsoft.com/azure/active-directory/develop/howto-create-service-principal-portal +""" + + +def main(): + client = CosmosDBManagementClient( + credential=DefaultAzureCredential(), + subscription_id="subid", + ) + + response = client.percentile_target.list_metrics( + resource_group_name="rg1", + account_name="ddb1", + target_region="East US", + filter="$filter=(name.value eq 'Probabilistic Bounded Staleness') and timeGrain eq duration'PT5M' and startTime eq '2017-11-19T23:53:55.2780000Z' and endTime eq '2017-11-20T00:13:55.2780000Z", + ) + for item in response: + print(item) + + +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBPercentileTargetGetMetrics.json +if __name__ == "__main__": + main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_delete.py index 
5d1659ee575..752c0e68c5c 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBPrivateEndpointConnectionDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBPrivateEndpointConnectionDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_get.py index 139e1c76b4f..ebaf856abae 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBPrivateEndpointConnectionGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBPrivateEndpointConnectionGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_list_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_list_get.py index 7a6b641fde3..124396447b4 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_list_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_list_get.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBPrivateEndpointConnectionListGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBPrivateEndpointConnectionListGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_update.py index 54dc86e477d..b926d3d7684 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_endpoint_connection_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -48,6 +46,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBPrivateEndpointConnectionUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBPrivateEndpointConnectionUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_link_resource_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_link_resource_get.py index 5fb2c821b0e..10ed98d9026 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_link_resource_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_link_resource_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBPrivateLinkResourceGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBPrivateLinkResourceGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_link_resource_list_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_link_resource_list_get.py index 215168d56ab..a6fd705e1e2 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_link_resource_list_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_private_link_resource_list_get.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBPrivateLinkResourceListGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBPrivateLinkResourceListGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_region_collection_get_metrics.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_region_collection_get_metrics.py new file mode 100644 index 00000000000..ca955c3e219 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_region_collection_get_metrics.py @@ -0,0 +1,47 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- + +from azure.identity import DefaultAzureCredential + +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +""" +# PREREQUISITES + pip install azure-identity + pip install azure-mgmt-cosmosdb +# USAGE + python cosmos_db_region_collection_get_metrics.py + + Before run the sample, please set the values of the client ID, tenant ID and client secret + of the AAD application as environment variables: AZURE_CLIENT_ID, AZURE_TENANT_ID, + AZURE_CLIENT_SECRET. For more info about how to get the value, please see: + https://docs.microsoft.com/azure/active-directory/develop/howto-create-service-principal-portal +""" + + +def main(): + client = CosmosDBManagementClient( + credential=DefaultAzureCredential(), + subscription_id="subid", + ) + + response = client.collection_region.list_metrics( + resource_group_name="rg1", + account_name="ddb1", + region="North Europe", + database_rid="databaseRid", + collection_rid="collectionRid", + filter="$filter=(name.value eq 'Total Requests') and timeGrain eq duration'PT5M' and startTime eq '2017-11-19T23:53:55.2780000Z' and endTime eq '2017-11-20T00:13:55.2780000Z", + ) + for item in response: + print(item) + + +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRegionCollectionGetMetrics.json +if __name__ == "__main__": + main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_database_account_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_database_account_get.py index 26818653124..dc5528116e1 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_database_account_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_database_account_get.py @@ -37,6 +37,6 @@ def main(): print(response) -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestorableDatabaseAccountGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestorableDatabaseAccountGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_database_account_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_database_account_list.py index 2b1e0576799..746f47f94c0 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_database_account_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_database_account_list.py @@ -37,6 +37,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestorableDatabaseAccountList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestorableDatabaseAccountList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_database_account_no_location_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_database_account_no_location_list.py index c60095852d2..edd799393be 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_database_account_no_location_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_database_account_no_location_list.py @@ -35,6 +35,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestorableDatabaseAccountNoLocationList.json +# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestorableDatabaseAccountNoLocationList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_gremlin_database_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_gremlin_database_list.py index 87f259b15db..5b53fb91261 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_gremlin_database_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_gremlin_database_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestorableGremlinDatabaseList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestorableGremlinDatabaseList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_gremlin_graph_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_gremlin_graph_list.py index c1951672555..bc9e7330028 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_gremlin_graph_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_gremlin_graph_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestorableGremlinGraphList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestorableGremlinGraphList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_gremlin_resource_list.py 
b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_gremlin_resource_list.py index 22871f03529..1f2333ac292 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_gremlin_resource_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_gremlin_resource_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestorableGremlinResourceList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestorableGremlinResourceList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_mongodb_collection_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_mongodb_collection_list.py index b0eb71a5d66..a0539c3228f 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_mongodb_collection_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_mongodb_collection_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestorableMongodbCollectionList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestorableMongodbCollectionList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_mongodb_database_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_mongodb_database_list.py index 163a5e0c803..c24f0b9e9cb 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_mongodb_database_list.py +++ 
b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_mongodb_database_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestorableMongodbDatabaseList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestorableMongodbDatabaseList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_mongodb_resource_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_mongodb_resource_list.py index 0a64a568976..ce1bb81f82a 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_mongodb_resource_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_mongodb_resource_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestorableMongodbResourceList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestorableMongodbResourceList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_sql_container_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_sql_container_list.py index bfc5b4f1e12..916c547af30 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_sql_container_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_sql_container_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestorableSqlContainerList.json +# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestorableSqlContainerList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_sql_database_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_sql_database_list.py index d54c3af6a8c..8f22037ec93 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_sql_database_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_sql_database_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestorableSqlDatabaseList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestorableSqlDatabaseList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_sql_resource_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_sql_resource_list.py index 8d268c914f5..0cc1162b12a 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_sql_resource_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_sql_resource_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestorableSqlResourceList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestorableSqlResourceList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_table_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_table_list.py index f41aaf5175f..87fe42d37a8 100644 
--- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_table_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_table_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestorableTableList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestorableTableList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_table_resource_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_table_resource_list.py index 99ff563e2c2..c8015d103a0 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_table_resource_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restorable_table_resource_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestorableTableResourceList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestorableTableResourceList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restore_database_account_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restore_database_account_create_update.py index 7871292ca6c..8a71713e0df 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restore_database_account_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_restore_database_account_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -61,6 +59,7 @@ def main(): "restoreMode": "PointInTime", "restoreSource": "/subscriptions/subid/providers/Microsoft.DocumentDB/locations/westus/restorableDatabaseAccounts/1a97b4bb-f6a0-430e-ade1-638d781830cc", "restoreTimestampInUtc": "2021-03-11T22:05:09Z", + "restoreWithTtlDisabled": False, }, }, "tags": {}, @@ -69,6 +68,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBRestoreDatabaseAccountCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBRestoreDatabaseAccountCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_services_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_services_list.py index 7d32450cb65..1df6daa2303 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_services_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_services_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBServicesList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBServicesList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_client_encryption_key_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_client_encryption_key_create_update.py index cd01d556098..e9796b4f409 100644 --- 
a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_client_encryption_key_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_client_encryption_key_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -56,6 +54,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlClientEncryptionKeyCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlClientEncryptionKeyCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_client_encryption_key_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_client_encryption_key_get.py index 3169ab1d6e8..126099a92f3 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_client_encryption_key_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_client_encryption_key_get.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlClientEncryptionKeyGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlClientEncryptionKeyGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_client_encryption_keys_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_client_encryption_keys_list.py index a82314d4ab0..067f63eca79 100644 --- 
a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_client_encryption_keys_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_client_encryption_keys_list.py @@ -39,6 +39,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlClientEncryptionKeysList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlClientEncryptionKeysList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_backup_information.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_backup_information.py index 9517e88d420..03c12a9ee75 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_backup_information.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_backup_information.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -42,6 +40,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlContainerBackupInformation.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlContainerBackupInformation.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_create_update.py index 074e5e4186f..278a3af1e18 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -81,6 +79,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlContainerCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlContainerCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_delete.py index eaa2aadcea1..e91d0ccfbd8 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_delete.py @@ -38,6 +38,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlContainerDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlContainerDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_get.py index e8f1e8da435..4285fed22e0 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_get.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlContainerGet.json +# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlContainerGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_list.py index 9b90da9be4d..746c40215ec 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_list.py @@ -39,6 +39,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlContainerList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlContainerList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_migrate_to_autoscale.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_migrate_to_autoscale.py index 02e1c9f1262..3a69f8b5880 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_migrate_to_autoscale.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_migrate_to_autoscale.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlContainerMigrateToAutoscale.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlContainerMigrateToAutoscale.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_migrate_to_manual_throughput.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_migrate_to_manual_throughput.py index 
e731b8e3da4..23e3e06a20d 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_migrate_to_manual_throughput.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_migrate_to_manual_throughput.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlContainerMigrateToManualThroughput.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlContainerMigrateToManualThroughput.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_throughput_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_throughput_get.py index b6ba4b97b03..6e309952cc9 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_throughput_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_throughput_get.py @@ -39,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlContainerThroughputGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlContainerThroughputGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_throughput_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_throughput_update.py index aef6d3f10d5..85f75a834c8 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_throughput_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_container_throughput_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is 
regenerated. # -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -46,6 +44,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlContainerThroughputUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlContainerThroughputUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_create_update.py index 8c009bd5021..de196bad91f 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -45,6 +43,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlDatabaseCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlDatabaseCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_delete.py index d5171e6c8f0..9f379eb1693 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlDatabaseDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlDatabaseDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_get.py index 7547633e948..6e66bc4d29e 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlDatabaseGet.json +# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlDatabaseGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_list.py index 990bdd43a4e..26f629271ae 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlDatabaseList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlDatabaseList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_migrate_to_autoscale.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_migrate_to_autoscale.py index b2fa0834694..18686cbed11 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_migrate_to_autoscale.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_migrate_to_autoscale.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlDatabaseMigrateToAutoscale.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlDatabaseMigrateToAutoscale.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_migrate_to_manual_throughput.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_migrate_to_manual_throughput.py index 13e1d7f33b2..a1d3cd64bdf 100644 --- 
a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_migrate_to_manual_throughput.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_migrate_to_manual_throughput.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlDatabaseMigrateToManualThroughput.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlDatabaseMigrateToManualThroughput.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_throughput_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_throughput_get.py index b1e3921abd8..b7677b0998a 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_throughput_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_throughput_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlDatabaseThroughputGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlDatabaseThroughputGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_throughput_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_throughput_update.py index a10df37d9cf..5455c735110 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_throughput_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_database_throughput_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -45,6 +43,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlDatabaseThroughputUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlDatabaseThroughputUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_create_update.py index daa51d0e1ef..929fdfd09cf 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -47,6 +45,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlRoleAssignmentCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlRoleAssignmentCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_delete.py index e7585d4ed22..1d7fa12bdc3 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlRoleAssignmentDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlRoleAssignmentDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_get.py index 0b0de5acc6a..4a237607316 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlRoleAssignmentGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlRoleAssignmentGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_list.py index 9d165ec8134..d1beb0a7fed 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_assignment_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlRoleAssignmentList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlRoleAssignmentList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_create_update.py index f2bcf93369c..c5be8aff953 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -59,6 +57,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlRoleDefinitionCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlRoleDefinitionCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_delete.py index 36d82d24231..6f5c2133e5f 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlRoleDefinitionDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlRoleDefinitionDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_get.py index e2b4bc45fda..3d1ccbbf992 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlRoleDefinitionGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlRoleDefinitionGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_list.py index 8037520ffca..e4fdded0070 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_role_definition_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlRoleDefinitionList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlRoleDefinitionList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_create_update.py index bd378dc11b2..e7510ec754d 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -45,6 +43,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlStoredProcedureCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlStoredProcedureCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_delete.py index 8f63398b687..a80b1e91d3c 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_delete.py @@ -39,6 +39,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlStoredProcedureDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlStoredProcedureDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_get.py index 8c7216ab3ac..b353f00abc1 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_get.py @@ -40,6 +40,6 @@ def main(): print(response) -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlStoredProcedureGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlStoredProcedureGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_list.py index 35acf392560..70140145b28 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_stored_procedure_list.py @@ -40,6 +40,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlStoredProcedureList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlStoredProcedureList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_create_update.py index 5c3cbbc8a75..fa963d27249 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -53,6 +51,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlTriggerCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlTriggerCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_delete.py index 4a5c9f7bcc7..f28acbf7a97 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_delete.py @@ -39,6 +39,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlTriggerDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlTriggerDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_get.py index c0903877e09..75955381d7d 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_get.py @@ -40,6 +40,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlTriggerGet.json +# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlTriggerGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_list.py index c0b51af78d8..0a7a4aa1e70 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_trigger_list.py @@ -40,6 +40,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlTriggerList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlTriggerList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_create_update.py index 88bb3e114ec..0e3a9290bf0 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -45,6 +43,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlUserDefinedFunctionCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlUserDefinedFunctionCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_delete.py index 05fe318ace8..129d127e718 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_delete.py @@ -39,6 +39,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlUserDefinedFunctionDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlUserDefinedFunctionDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_get.py index f5fb0f4e86e..7074249c5d8 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_get.py @@ -40,6 +40,6 @@ def main(): print(response) -# x-ms-original-file: 
specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlUserDefinedFunctionGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlUserDefinedFunctionGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_list.py index b4b0e9a064e..51f6bb045cd 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_sql_user_defined_function_list.py @@ -40,6 +40,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBSqlUserDefinedFunctionList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBSqlUserDefinedFunctionList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_backup_information.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_backup_information.py index 2a2368f94bf..b1620a5ee84 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_backup_information.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_backup_information.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -41,6 +39,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBTableBackupInformation.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBTableBackupInformation.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_create_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_create_update.py index a6725d62e7b..52f16de61bb 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_create_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_create_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -45,6 +43,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBTableCreateUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBTableCreateUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_delete.py index 020f3fdc958..06a400f2a7a 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBTableDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBTableDelete.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_get.py index c03b03be7ba..901ae4c96b5 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBTableGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBTableGet.json if __name__ == "__main__": 
main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_list.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_list.py index c4390a5d2bc..bb4221fdf20 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_list.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_list.py @@ -38,6 +38,6 @@ def main(): print(item) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBTableList.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBTableList.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_migrate_to_autoscale.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_migrate_to_autoscale.py index 93a3e4c93cb..ae97bf025cf 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_migrate_to_autoscale.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_migrate_to_autoscale.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBTableMigrateToAutoscale.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBTableMigrateToAutoscale.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_migrate_to_manual_throughput.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_migrate_to_manual_throughput.py index 98726063d6d..be127cfac6d 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_migrate_to_manual_throughput.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_migrate_to_manual_throughput.py @@ -38,6 +38,6 @@ def main(): 
print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBTableMigrateToManualThroughput.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBTableMigrateToManualThroughput.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_throughput_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_throughput_get.py index 97ec0b97f49..def51d51e76 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_throughput_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_throughput_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBTableThroughputGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBTableThroughputGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_throughput_update.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_throughput_update.py index 41ad2235d4a..bf3e427d5b1 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_throughput_update.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_db_table_throughput_update.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -45,6 +43,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/CosmosDBTableThroughputUpdate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBTableThroughputUpdate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_dbp_key_range_id_get_metrics.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_dbp_key_range_id_get_metrics.py new file mode 100644 index 00000000000..ac158a963b7 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_dbp_key_range_id_get_metrics.py @@ -0,0 +1,47 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- + +from azure.identity import DefaultAzureCredential + +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +""" +# PREREQUISITES + pip install azure-identity + pip install azure-mgmt-cosmosdb +# USAGE + python cosmos_dbp_key_range_id_get_metrics.py + + Before run the sample, please set the values of the client ID, tenant ID and client secret + of the AAD application as environment variables: AZURE_CLIENT_ID, AZURE_TENANT_ID, + AZURE_CLIENT_SECRET. 
For more info about how to get the value, please see: + https://docs.microsoft.com/azure/active-directory/develop/howto-create-service-principal-portal +""" + + +def main(): + client = CosmosDBManagementClient( + credential=DefaultAzureCredential(), + subscription_id="subid", + ) + + response = client.partition_key_range_id.list_metrics( + resource_group_name="rg1", + account_name="ddb1", + database_rid="databaseRid", + collection_rid="collectionRid", + partition_key_range_id="0", + filter="$filter=(name.value eq 'Max RUs Per Second') and timeGrain eq duration'PT1M' and startTime eq '2017-11-19T23:53:55.2780000Z' and endTime eq '2017-11-20T23:58:55.2780000Z", + ) + for item in response: + print(item) + + +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBPKeyRangeIdGetMetrics.json +if __name__ == "__main__": + main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_dbp_key_range_id_region_get_metrics.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_dbp_key_range_id_region_get_metrics.py new file mode 100644 index 00000000000..5d6fa67d98f --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/cosmos_dbp_key_range_id_region_get_metrics.py @@ -0,0 +1,48 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- + +from azure.identity import DefaultAzureCredential + +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +""" +# PREREQUISITES + pip install azure-identity + pip install azure-mgmt-cosmosdb +# USAGE + python cosmos_dbp_key_range_id_region_get_metrics.py + + Before run the sample, please set the values of the client ID, tenant ID and client secret + of the AAD application as environment variables: AZURE_CLIENT_ID, AZURE_TENANT_ID, + AZURE_CLIENT_SECRET. For more info about how to get the value, please see: + https://docs.microsoft.com/azure/active-directory/develop/howto-create-service-principal-portal +""" + + +def main(): + client = CosmosDBManagementClient( + credential=DefaultAzureCredential(), + subscription_id="subid", + ) + + response = client.partition_key_range_id_region.list_metrics( + resource_group_name="rg1", + account_name="ddb1", + region="West US", + database_rid="databaseRid", + collection_rid="collectionRid", + partition_key_range_id="0", + filter="$filter=(name.value eq 'Max RUs Per Second') and timeGrain eq duration'PT1M' and startTime eq '2017-11-19T23:53:55.2780000Z' and endTime eq '2017-11-20T23:58:55.2780000Z", + ) + for item in response: + print(item) + + +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/CosmosDBPKeyRangeIdRegionGetMetrics.json +if __name__ == "__main__": + main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/services/sqldedicatedgateway/cosmos_db_sql_dedicated_gateway_service_create.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/services/sqldedicatedgateway/cosmos_db_sql_dedicated_gateway_service_create.py index c118175b00a..b189ad90ab2 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/services/sqldedicatedgateway/cosmos_db_sql_dedicated_gateway_service_create.py +++ 
b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/services/sqldedicatedgateway/cosmos_db_sql_dedicated_gateway_service_create.py @@ -6,8 +6,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, IO, Union - from azure.identity import DefaultAzureCredential from azure.mgmt.cosmosdb import CosmosDBManagementClient @@ -48,6 +46,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/services/sqldedicatedgateway/CosmosDBSqlDedicatedGatewayServiceCreate.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/services/sqldedicatedgateway/CosmosDBSqlDedicatedGatewayServiceCreate.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/services/sqldedicatedgateway/cosmos_db_sql_dedicated_gateway_service_delete.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/services/sqldedicatedgateway/cosmos_db_sql_dedicated_gateway_service_delete.py index f95fc7e2841..e4fa8b45178 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/services/sqldedicatedgateway/cosmos_db_sql_dedicated_gateway_service_delete.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/services/sqldedicatedgateway/cosmos_db_sql_dedicated_gateway_service_delete.py @@ -37,6 +37,6 @@ def main(): ).result() -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/services/sqldedicatedgateway/CosmosDBSqlDedicatedGatewayServiceDelete.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/services/sqldedicatedgateway/CosmosDBSqlDedicatedGatewayServiceDelete.json if __name__ == "__main__": main() diff --git 
a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/services/sqldedicatedgateway/cosmos_db_sql_dedicated_gateway_service_get.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/services/sqldedicatedgateway/cosmos_db_sql_dedicated_gateway_service_get.py index 21f9f28deef..1898e56cf9b 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/services/sqldedicatedgateway/cosmos_db_sql_dedicated_gateway_service_get.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_samples/services/sqldedicatedgateway/cosmos_db_sql_dedicated_gateway_service_get.py @@ -38,6 +38,6 @@ def main(): print(response) -# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-05-15/examples/services/sqldedicatedgateway/CosmosDBSqlDedicatedGatewayServiceGet.json +# x-ms-original-file: specification/cosmos-db/resource-manager/Microsoft.DocumentDB/stable/2024-08-15/examples/services/sqldedicatedgateway/CosmosDBSqlDedicatedGatewayServiceGet.json if __name__ == "__main__": main() diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/conftest.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/conftest.py new file mode 100644 index 00000000000..c6d1ee70d05 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/conftest.py @@ -0,0 +1,35 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import os +import pytest +from dotenv import load_dotenv +from devtools_testutils import ( + test_proxy, + add_general_regex_sanitizer, + add_body_key_sanitizer, + add_header_regex_sanitizer, +) + +load_dotenv() + + +# aovid record sensitive identity information in recordings +@pytest.fixture(scope="session", autouse=True) +def add_sanitizers(test_proxy): + cosmosdbmanagement_subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "00000000-0000-0000-0000-000000000000") + cosmosdbmanagement_tenant_id = os.environ.get("AZURE_TENANT_ID", "00000000-0000-0000-0000-000000000000") + cosmosdbmanagement_client_id = os.environ.get("AZURE_CLIENT_ID", "00000000-0000-0000-0000-000000000000") + cosmosdbmanagement_client_secret = os.environ.get("AZURE_CLIENT_SECRET", "00000000-0000-0000-0000-000000000000") + add_general_regex_sanitizer(regex=cosmosdbmanagement_subscription_id, value="00000000-0000-0000-0000-000000000000") + add_general_regex_sanitizer(regex=cosmosdbmanagement_tenant_id, value="00000000-0000-0000-0000-000000000000") + add_general_regex_sanitizer(regex=cosmosdbmanagement_client_id, value="00000000-0000-0000-0000-000000000000") + add_general_regex_sanitizer(regex=cosmosdbmanagement_client_secret, value="00000000-0000-0000-0000-000000000000") + + add_header_regex_sanitizer(key="Set-Cookie", value="[set-cookie;]") + add_header_regex_sanitizer(key="Cookie", value="cookie;") + add_body_key_sanitizer(json_path="$..access_token", value="access_token") diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_clusters_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_clusters_operations.py new file mode 100644 index 00000000000..949bc725aeb --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_clusters_operations.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# 
-------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementCassandraClustersOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_by_subscription(self, resource_group): + response = self.client.cassandra_clusters.list_by_subscription( + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_by_resource_group(self, resource_group): + response = self.client.cassandra_clusters.list_by_resource_group( + resource_group_name=resource_group.name, + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get(self, resource_group): + response = self.client.cassandra_clusters.get( + resource_group_name=resource_group.name, + cluster_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete(self, resource_group): + response = self.client.cassandra_clusters.begin_delete( + resource_group_name=resource_group.name, + cluster_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update(self, resource_group): + response = self.client.cassandra_clusters.begin_create_update( + resource_group_name=resource_group.name, + cluster_name="str", + body={ + "id": "str", + "identity": {"principalId": "str", "tenantId": "str", "type": "str"}, + "location": "str", + "name": "str", + "properties": { + "authenticationMethod": "str", + "azureConnectionMethod": "str", + "cassandraAuditLoggingEnabled": bool, + "cassandraVersion": "str", + "clientCertificates": [{"pem": "str"}], + "clusterNameOverride": "str", + "deallocated": bool, + "delegatedManagementSubnetId": "str", + "externalGossipCertificates": [{"pem": "str"}], + "externalSeedNodes": [{"ipAddress": "str"}], + "gossipCertificates": [{"pem": "str"}], + "hoursBetweenBackups": 0, + "initialCassandraAdminPassword": "str", + "privateLinkResourceId": "str", + "prometheusEndpoint": {"ipAddress": "str"}, + "provisionError": {"additionalErrorInfo": "str", "code": "str", "message": "str", "target": "str"}, + "provisioningState": "str", + "repairEnabled": bool, + "restoreFromBackupId": "str", + "seedNodes": [{"ipAddress": "str"}], + }, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_update(self, resource_group): + response = self.client.cassandra_clusters.begin_update( + resource_group_name=resource_group.name, + cluster_name="str", + body={ + "id": "str", + "identity": {"principalId": "str", "tenantId": "str", "type": "str"}, + "location": "str", + "name": "str", + "properties": { + "authenticationMethod": "str", + "azureConnectionMethod": "str", + "cassandraAuditLoggingEnabled": bool, + "cassandraVersion": "str", + "clientCertificates": [{"pem": "str"}], + "clusterNameOverride": "str", + "deallocated": bool, + "delegatedManagementSubnetId": "str", + "externalGossipCertificates": [{"pem": "str"}], + "externalSeedNodes": [{"ipAddress": "str"}], + "gossipCertificates": [{"pem": "str"}], + "hoursBetweenBackups": 0, + "initialCassandraAdminPassword": "str", + "privateLinkResourceId": "str", + "prometheusEndpoint": {"ipAddress": "str"}, + "provisionError": {"additionalErrorInfo": "str", "code": "str", "message": "str", "target": "str"}, + "provisioningState": "str", + "repairEnabled": bool, + "restoreFromBackupId": "str", + "seedNodes": [{"ipAddress": "str"}], + }, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_invoke_command(self, resource_group): + response = self.client.cassandra_clusters.begin_invoke_command( + resource_group_name=resource_group.name, + cluster_name="str", + body={ + "command": "str", + "host": "str", + "arguments": {"str": "str"}, + "cassandra-stop-start": bool, + "readwrite": bool, + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_deallocate(self, resource_group): + response = self.client.cassandra_clusters.begin_deallocate( + resource_group_name=resource_group.name, + cluster_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_start(self, resource_group): + response = self.client.cassandra_clusters.begin_start( + resource_group_name=resource_group.name, + cluster_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_status(self, resource_group): + response = self.client.cassandra_clusters.status( + resource_group_name=resource_group.name, + cluster_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_clusters_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_clusters_operations_async.py new file mode 100644 index 00000000000..867916e7057 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_clusters_operations_async.py @@ -0,0 +1,226 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. 
+# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementCassandraClustersOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_by_subscription(self, resource_group): + response = self.client.cassandra_clusters.list_by_subscription( + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_by_resource_group(self, resource_group): + response = self.client.cassandra_clusters.list_by_resource_group( + resource_group_name=resource_group.name, + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get(self, resource_group): + response = await self.client.cassandra_clusters.get( + resource_group_name=resource_group.name, + cluster_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete(self, resource_group): + response = await ( + await self.client.cassandra_clusters.begin_delete( + resource_group_name=resource_group.name, + cluster_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update(self, resource_group): + response = await ( + await self.client.cassandra_clusters.begin_create_update( + resource_group_name=resource_group.name, + cluster_name="str", + body={ + "id": "str", + "identity": {"principalId": "str", "tenantId": "str", "type": "str"}, + "location": "str", + "name": "str", + "properties": { + "authenticationMethod": "str", + "azureConnectionMethod": "str", + "cassandraAuditLoggingEnabled": bool, + "cassandraVersion": "str", + "clientCertificates": [{"pem": "str"}], + "clusterNameOverride": "str", + "deallocated": bool, + "delegatedManagementSubnetId": "str", + "externalGossipCertificates": [{"pem": "str"}], + "externalSeedNodes": [{"ipAddress": "str"}], + "gossipCertificates": [{"pem": "str"}], + "hoursBetweenBackups": 0, + "initialCassandraAdminPassword": "str", + "privateLinkResourceId": "str", + "prometheusEndpoint": {"ipAddress": "str"}, + "provisionError": { + "additionalErrorInfo": "str", + "code": "str", + "message": "str", + "target": "str", + }, + "provisioningState": "str", + "repairEnabled": bool, + "restoreFromBackupId": "str", + "seedNodes": [{"ipAddress": "str"}], + }, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_update(self, resource_group): + response = await ( + await self.client.cassandra_clusters.begin_update( + resource_group_name=resource_group.name, + cluster_name="str", + body={ + "id": "str", + "identity": {"principalId": "str", "tenantId": "str", "type": "str"}, + "location": "str", + "name": "str", + "properties": { + "authenticationMethod": "str", + "azureConnectionMethod": "str", + "cassandraAuditLoggingEnabled": bool, + "cassandraVersion": "str", + "clientCertificates": [{"pem": "str"}], + "clusterNameOverride": "str", + "deallocated": bool, + "delegatedManagementSubnetId": "str", + "externalGossipCertificates": [{"pem": "str"}], + "externalSeedNodes": [{"ipAddress": "str"}], + "gossipCertificates": [{"pem": "str"}], + "hoursBetweenBackups": 0, + "initialCassandraAdminPassword": "str", + "privateLinkResourceId": "str", + "prometheusEndpoint": {"ipAddress": "str"}, + "provisionError": { + "additionalErrorInfo": "str", + "code": "str", + "message": "str", + "target": "str", + }, + "provisioningState": "str", + "repairEnabled": bool, + "restoreFromBackupId": "str", + "seedNodes": [{"ipAddress": "str"}], + }, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_invoke_command(self, resource_group): + response = await ( + await self.client.cassandra_clusters.begin_invoke_command( + resource_group_name=resource_group.name, + cluster_name="str", + body={ + "command": "str", + "host": "str", + "arguments": {"str": "str"}, + "cassandra-stop-start": bool, + "readwrite": bool, + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_deallocate(self, resource_group): + response = await ( + await self.client.cassandra_clusters.begin_deallocate( + resource_group_name=resource_group.name, + cluster_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_start(self, resource_group): + response = await ( + await self.client.cassandra_clusters.begin_start( + resource_group_name=resource_group.name, + cluster_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_status(self, resource_group): + response = await self.client.cassandra_clusters.status( + resource_group_name=resource_group.name, + cluster_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_data_centers_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_data_centers_operations.py new file mode 100644 index 00000000000..2aa988b82d3 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_data_centers_operations.py @@ -0,0 +1,147 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementCassandraDataCentersOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.cassandra_data_centers.list( + resource_group_name=resource_group.name, + cluster_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get(self, resource_group): + response = self.client.cassandra_data_centers.get( + resource_group_name=resource_group.name, + cluster_name="str", + data_center_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete(self, resource_group): + response = self.client.cassandra_data_centers.begin_delete( + resource_group_name=resource_group.name, + cluster_name="str", + data_center_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update(self, resource_group): + response = self.client.cassandra_data_centers.begin_create_update( + resource_group_name=resource_group.name, + cluster_name="str", + data_center_name="str", + body={ + "id": "str", + "name": "str", + "properties": { + "authenticationMethodLdapProperties": { + "connectionTimeoutInMs": 0, + "searchBaseDistinguishedName": "str", + "searchFilterTemplate": "str", + "serverCertificates": [{"pem": "str"}], + "serverHostname": "str", + "serverPort": 0, + "serviceUserDistinguishedName": "str", + "serviceUserPassword": "str", + }, + "availabilityZone": bool, + "backupStorageCustomerKeyUri": "str", + "base64EncodedCassandraYamlFragment": "str", + "dataCenterLocation": "str", + "deallocated": bool, + "delegatedSubnetId": "str", + "diskCapacity": 0, + "diskSku": "str", + "managedDiskCustomerKeyUri": "str", + "nodeCount": 0, + "privateEndpointIpAddress": "str", + "provisionError": {"additionalErrorInfo": "str", "code": "str", "message": "str", "target": "str"}, + "provisioningState": "str", + "seedNodes": [{"ipAddress": "str"}], + "sku": "str", + }, + 
"type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_update(self, resource_group): + response = self.client.cassandra_data_centers.begin_update( + resource_group_name=resource_group.name, + cluster_name="str", + data_center_name="str", + body={ + "id": "str", + "name": "str", + "properties": { + "authenticationMethodLdapProperties": { + "connectionTimeoutInMs": 0, + "searchBaseDistinguishedName": "str", + "searchFilterTemplate": "str", + "serverCertificates": [{"pem": "str"}], + "serverHostname": "str", + "serverPort": 0, + "serviceUserDistinguishedName": "str", + "serviceUserPassword": "str", + }, + "availabilityZone": bool, + "backupStorageCustomerKeyUri": "str", + "base64EncodedCassandraYamlFragment": "str", + "dataCenterLocation": "str", + "deallocated": bool, + "delegatedSubnetId": "str", + "diskCapacity": 0, + "diskSku": "str", + "managedDiskCustomerKeyUri": "str", + "nodeCount": 0, + "privateEndpointIpAddress": "str", + "provisionError": {"additionalErrorInfo": "str", "code": "str", "message": "str", "target": "str"}, + "provisioningState": "str", + "seedNodes": [{"ipAddress": "str"}], + "sku": "str", + }, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_data_centers_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_data_centers_operations_async.py new file mode 100644 index 00000000000..96bea692e54 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_data_centers_operations_async.py @@ -0,0 +1,164 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementCassandraDataCentersOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.cassandra_data_centers.list( + resource_group_name=resource_group.name, + cluster_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get(self, resource_group): + response = await self.client.cassandra_data_centers.get( + resource_group_name=resource_group.name, + cluster_name="str", + data_center_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete(self, resource_group): + response = await ( + await self.client.cassandra_data_centers.begin_delete( + resource_group_name=resource_group.name, + cluster_name="str", + data_center_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update(self, resource_group): + response = await ( + await self.client.cassandra_data_centers.begin_create_update( + resource_group_name=resource_group.name, + cluster_name="str", + data_center_name="str", + body={ + "id": "str", + "name": "str", + "properties": { + "authenticationMethodLdapProperties": { + "connectionTimeoutInMs": 0, + "searchBaseDistinguishedName": "str", + "searchFilterTemplate": "str", + "serverCertificates": [{"pem": "str"}], + "serverHostname": "str", + "serverPort": 0, + "serviceUserDistinguishedName": "str", + "serviceUserPassword": "str", + }, + "availabilityZone": bool, + "backupStorageCustomerKeyUri": "str", + "base64EncodedCassandraYamlFragment": "str", + "dataCenterLocation": "str", + "deallocated": bool, + "delegatedSubnetId": "str", + "diskCapacity": 0, + "diskSku": "str", + "managedDiskCustomerKeyUri": "str", + "nodeCount": 0, + "privateEndpointIpAddress": "str", + "provisionError": { + "additionalErrorInfo": "str", + "code": "str", + "message": "str", + "target": "str", + }, + 
"provisioningState": "str", + "seedNodes": [{"ipAddress": "str"}], + "sku": "str", + }, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_update(self, resource_group): + response = await ( + await self.client.cassandra_data_centers.begin_update( + resource_group_name=resource_group.name, + cluster_name="str", + data_center_name="str", + body={ + "id": "str", + "name": "str", + "properties": { + "authenticationMethodLdapProperties": { + "connectionTimeoutInMs": 0, + "searchBaseDistinguishedName": "str", + "searchFilterTemplate": "str", + "serverCertificates": [{"pem": "str"}], + "serverHostname": "str", + "serverPort": 0, + "serviceUserDistinguishedName": "str", + "serviceUserPassword": "str", + }, + "availabilityZone": bool, + "backupStorageCustomerKeyUri": "str", + "base64EncodedCassandraYamlFragment": "str", + "dataCenterLocation": "str", + "deallocated": bool, + "delegatedSubnetId": "str", + "diskCapacity": 0, + "diskSku": "str", + "managedDiskCustomerKeyUri": "str", + "nodeCount": 0, + "privateEndpointIpAddress": "str", + "provisionError": { + "additionalErrorInfo": "str", + "code": "str", + "message": "str", + "target": "str", + }, + "provisioningState": "str", + "seedNodes": [{"ipAddress": "str"}], + "sku": "str", + }, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_resources_operations.py new file mode 100644 index 00000000000..3f5970be1f6 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_resources_operations.py @@ -0,0 +1,298 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementCassandraResourcesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_cassandra_keyspaces(self, resource_group): + response = self.client.cassandra_resources.list_cassandra_keyspaces( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_cassandra_keyspace(self, resource_group): + response = self.client.cassandra_resources.get_cassandra_keyspace( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_cassandra_keyspace(self, resource_group): + response = self.client.cassandra_resources.begin_create_update_cassandra_keyspace( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + create_update_cassandra_keyspace_parameters={ + "resource": {"id": "str"}, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_cassandra_keyspace(self, resource_group): + response = self.client.cassandra_resources.begin_delete_cassandra_keyspace( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_cassandra_keyspace_throughput(self, resource_group): + response = self.client.cassandra_resources.get_cassandra_keyspace_throughput( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_update_cassandra_keyspace_throughput(self, resource_group): + response = self.client.cassandra_resources.begin_update_cassandra_keyspace_throughput( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_cassandra_keyspace_to_autoscale(self, resource_group): + response = self.client.cassandra_resources.begin_migrate_cassandra_keyspace_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_cassandra_keyspace_to_manual_throughput(self, resource_group): + response = self.client.cassandra_resources.begin_migrate_cassandra_keyspace_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_cassandra_tables(self, resource_group): + response = self.client.cassandra_resources.list_cassandra_tables( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_cassandra_table(self, resource_group): + response = self.client.cassandra_resources.get_cassandra_table( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + table_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_cassandra_table(self, resource_group): + response = self.client.cassandra_resources.begin_create_update_cassandra_table( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + table_name="str", + create_update_cassandra_table_parameters={ + "resource": { + "id": "str", + "analyticalStorageTtl": 0, + "defaultTtl": 0, + "schema": { + "clusterKeys": [{"name": "str", "orderBy": "str"}], + "columns": [{"name": "str", "type": "str"}], + "partitionKeys": [{"name": "str"}], + }, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_cassandra_table(self, resource_group): + response = self.client.cassandra_resources.begin_delete_cassandra_table( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + table_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_cassandra_table_throughput(self, resource_group): + response = self.client.cassandra_resources.get_cassandra_table_throughput( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + table_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_update_cassandra_table_throughput(self, resource_group): + response = self.client.cassandra_resources.begin_update_cassandra_table_throughput( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + table_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_cassandra_table_to_autoscale(self, resource_group): + response = self.client.cassandra_resources.begin_migrate_cassandra_table_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + table_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_cassandra_table_to_manual_throughput(self, resource_group): + response = self.client.cassandra_resources.begin_migrate_cassandra_table_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + table_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_resources_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_resources_operations_async.py new file mode 100644 index 00000000000..1013dcf9473 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_cassandra_resources_operations_async.py @@ -0,0 +1,319 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementCassandraResourcesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_cassandra_keyspaces(self, resource_group): + response = self.client.cassandra_resources.list_cassandra_keyspaces( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_cassandra_keyspace(self, resource_group): + response = await self.client.cassandra_resources.get_cassandra_keyspace( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_cassandra_keyspace(self, resource_group): + response = await ( + await self.client.cassandra_resources.begin_create_update_cassandra_keyspace( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + create_update_cassandra_keyspace_parameters={ + "resource": {"id": "str"}, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_cassandra_keyspace(self, resource_group): + response = await ( + await self.client.cassandra_resources.begin_delete_cassandra_keyspace( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_cassandra_keyspace_throughput(self, resource_group): + response = await self.client.cassandra_resources.get_cassandra_keyspace_throughput( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_update_cassandra_keyspace_throughput(self, resource_group): + response = await ( + await self.client.cassandra_resources.begin_update_cassandra_keyspace_throughput( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_cassandra_keyspace_to_autoscale(self, resource_group): + response = await ( + await self.client.cassandra_resources.begin_migrate_cassandra_keyspace_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_cassandra_keyspace_to_manual_throughput(self, resource_group): + response = await ( + await self.client.cassandra_resources.begin_migrate_cassandra_keyspace_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_cassandra_tables(self, resource_group): + response = self.client.cassandra_resources.list_cassandra_tables( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_cassandra_table(self, resource_group): + response = await self.client.cassandra_resources.get_cassandra_table( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + table_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_cassandra_table(self, resource_group): + response = await ( + await self.client.cassandra_resources.begin_create_update_cassandra_table( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + table_name="str", + create_update_cassandra_table_parameters={ + "resource": { + "id": "str", + "analyticalStorageTtl": 0, + "defaultTtl": 0, + "schema": { + "clusterKeys": [{"name": "str", "orderBy": "str"}], + "columns": [{"name": "str", "type": "str"}], + "partitionKeys": [{"name": "str"}], + }, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_cassandra_table(self, resource_group): + response = await ( + await self.client.cassandra_resources.begin_delete_cassandra_table( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + table_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_cassandra_table_throughput(self, resource_group): + response = await self.client.cassandra_resources.get_cassandra_table_throughput( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + table_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_update_cassandra_table_throughput(self, resource_group): + response = await ( + await self.client.cassandra_resources.begin_update_cassandra_table_throughput( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + table_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_cassandra_table_to_autoscale(self, resource_group): + response = await ( + await self.client.cassandra_resources.begin_migrate_cassandra_table_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + table_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_cassandra_table_to_manual_throughput(self, resource_group): + response = await ( + await self.client.cassandra_resources.begin_migrate_cassandra_table_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + keyspace_name="str", + table_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_operations.py new file mode 100644 index 00000000000..e59bd76e9ab --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_operations.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementCollectionOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metrics(self, resource_group): + response = self.client.collection.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + collection_rid="str", + filter="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_usages(self, resource_group): + response = self.client.collection.list_usages( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + collection_rid="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metric_definitions(self, resource_group): + response = self.client.collection.list_metric_definitions( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + collection_rid="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... 
diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_operations_async.py new file mode 100644 index 00000000000..e63af0204e2 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_operations_async.py @@ -0,0 +1,63 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementCollectionOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metrics(self, resource_group): + response = self.client.collection.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + collection_rid="str", + filter="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_usages(self, resource_group): + response = self.client.collection.list_usages( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + collection_rid="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metric_definitions(self, resource_group): + response = self.client.collection.list_metric_definitions( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + collection_rid="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_operations.py new file mode 100644 index 00000000000..c3e78965d93 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_operations.py @@ -0,0 +1,48 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementCollectionPartitionOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metrics(self, resource_group): + response = self.client.collection_partition.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + collection_rid="str", + filter="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_usages(self, resource_group): + response = self.client.collection_partition.list_usages( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + collection_rid="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... 
diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_operations_async.py new file mode 100644 index 00000000000..f887ec685a0 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_operations_async.py @@ -0,0 +1,49 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementCollectionPartitionOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metrics(self, resource_group): + response = self.client.collection_partition.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + collection_rid="str", + filter="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_usages(self, resource_group): + response = self.client.collection_partition.list_usages( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + collection_rid="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_region_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_region_operations.py new file mode 100644 index 00000000000..3c9879f143c --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_region_operations.py @@ -0,0 +1,35 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementCollectionPartitionRegionOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metrics(self, resource_group): + response = self.client.collection_partition_region.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + region="str", + database_rid="str", + collection_rid="str", + filter="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_region_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_region_operations_async.py new file mode 100644 index 00000000000..7f4aa459064 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_partition_region_operations_async.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementCollectionPartitionRegionOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metrics(self, resource_group): + response = self.client.collection_partition_region.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + region="str", + database_rid="str", + collection_rid="str", + filter="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_region_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_region_operations.py new file mode 100644 index 00000000000..ba24b84db2b --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_region_operations.py @@ -0,0 +1,35 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementCollectionRegionOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metrics(self, resource_group): + response = self.client.collection_region.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + region="str", + database_rid="str", + collection_rid="str", + filter="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_region_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_region_operations_async.py new file mode 100644 index 00000000000..c6084c38ea9 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_collection_region_operations_async.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementCollectionRegionOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metrics(self, resource_group): + response = self.client.collection_region.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + region="str", + database_rid="str", + collection_rid="str", + filter="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_account_region_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_account_region_operations.py new file mode 100644 index 00000000000..dceaa34d4c5 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_account_region_operations.py @@ -0,0 +1,33 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementDatabaseAccountRegionOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metrics(self, resource_group): + response = self.client.database_account_region.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + region="str", + filter="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_account_region_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_account_region_operations_async.py new file mode 100644 index 00000000000..b244a424639 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_account_region_operations_async.py @@ -0,0 +1,34 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementDatabaseAccountRegionOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metrics(self, resource_group): + response = self.client.database_account_region.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + region="str", + filter="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_accounts_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_accounts_operations.py new file mode 100644 index 00000000000..a455e4dbb29 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_accounts_operations.py @@ -0,0 +1,381 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementDatabaseAccountsOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get(self, resource_group): + response = self.client.database_accounts.get( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_update(self, resource_group): + response = self.client.database_accounts.begin_update( + resource_group_name=resource_group.name, + account_name="str", + update_parameters={ + "analyticalStorageConfiguration": {"schemaType": "str"}, + "apiProperties": {"serverVersion": "str"}, + "backupPolicy": "backup_policy", + "capabilities": [{"name": "str"}], + "capacity": {"totalThroughputLimit": 0}, + "connectorOffer": "str", + "consistencyPolicy": { + "defaultConsistencyLevel": "str", + "maxIntervalInSeconds": 0, + "maxStalenessPrefix": 0, + }, + "cors": [ + { + "allowedOrigins": "str", + "allowedHeaders": "str", + "allowedMethods": "str", + "exposedHeaders": "str", + "maxAgeInSeconds": 0, + } + ], + "customerManagedKeyStatus": "str", + "defaultIdentity": "str", + "disableKeyBasedMetadataWriteAccess": bool, + "disableLocalAuth": bool, + "enableAnalyticalStorage": bool, + "enableAutomaticFailover": bool, + "enableBurstCapacity": bool, + "enableCassandraConnector": bool, + "enableFreeTier": bool, + 
"enableMultipleWriteLocations": bool, + "enablePartitionMerge": bool, + "identity": { + "principalId": "str", + "tenantId": "str", + "type": "str", + "userAssignedIdentities": {"str": {"clientId": "str", "principalId": "str"}}, + }, + "ipRules": [{"ipAddressOrRange": "str"}], + "isVirtualNetworkFilterEnabled": bool, + "keyVaultKeyUri": "str", + "keysMetadata": { + "primaryMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + "primaryReadonlyMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + "secondaryMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + "secondaryReadonlyMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + }, + "location": "str", + "locations": [ + { + "documentEndpoint": "str", + "failoverPriority": 0, + "id": "str", + "isZoneRedundant": bool, + "locationName": "str", + "provisioningState": "str", + } + ], + "minimalTlsVersion": "str", + "networkAclBypass": "str", + "networkAclBypassResourceIds": ["str"], + "publicNetworkAccess": "str", + "tags": {"str": "str"}, + "virtualNetworkRules": [{"id": "str", "ignoreMissingVNetServiceEndpoint": bool}], + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_or_update(self, resource_group): + response = self.client.database_accounts.begin_create_or_update( + resource_group_name=resource_group.name, + account_name="str", + create_update_parameters={ + "databaseAccountOfferType": "Standard", + "locations": [ + { + "documentEndpoint": "str", + "failoverPriority": 0, + "id": "str", + "isZoneRedundant": bool, + "locationName": "str", + "provisioningState": "str", + } + ], + "analyticalStorageConfiguration": {"schemaType": "str"}, + "apiProperties": {"serverVersion": "str"}, + "backupPolicy": "backup_policy", + "capabilities": [{"name": "str"}], + "capacity": {"totalThroughputLimit": 0}, + "connectorOffer": "str", + "consistencyPolicy": { + "defaultConsistencyLevel": "str", + "maxIntervalInSeconds": 0, + "maxStalenessPrefix": 0, + }, + "cors": [ + { + "allowedOrigins": "str", + "allowedHeaders": "str", + "allowedMethods": "str", + "exposedHeaders": "str", + "maxAgeInSeconds": 0, + } + ], + "createMode": "Default", + "customerManagedKeyStatus": "str", + "defaultIdentity": "str", + "disableKeyBasedMetadataWriteAccess": bool, + "disableLocalAuth": bool, + "enableAnalyticalStorage": bool, + "enableAutomaticFailover": bool, + "enableBurstCapacity": bool, + "enableCassandraConnector": bool, + "enableFreeTier": bool, + "enableMultipleWriteLocations": bool, + "enablePartitionMerge": bool, + "id": "str", + "identity": { + "principalId": "str", + "tenantId": "str", + "type": "str", + "userAssignedIdentities": {"str": {"clientId": "str", "principalId": "str"}}, + }, + "ipRules": [{"ipAddressOrRange": "str"}], + "isVirtualNetworkFilterEnabled": bool, + "keyVaultKeyUri": "str", + "keysMetadata": { + "primaryMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + "primaryReadonlyMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + "secondaryMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + "secondaryReadonlyMasterKey": 
{"generationTime": "2020-02-20 00:00:00"}, + }, + "kind": "str", + "location": "str", + "minimalTlsVersion": "str", + "name": "str", + "networkAclBypass": "str", + "networkAclBypassResourceIds": ["str"], + "publicNetworkAccess": "str", + "restoreParameters": { + "databasesToRestore": [{"collectionNames": ["str"], "databaseName": "str"}], + "gremlinDatabasesToRestore": [{"databaseName": "str", "graphNames": ["str"]}], + "restoreMode": "str", + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + "tablesToRestore": ["str"], + }, + "tags": {"str": "str"}, + "type": "str", + "virtualNetworkRules": [{"id": "str", "ignoreMissingVNetServiceEndpoint": bool}], + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete(self, resource_group): + response = self.client.database_accounts.begin_delete( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_failover_priority_change(self, resource_group): + response = self.client.database_accounts.begin_failover_priority_change( + resource_group_name=resource_group.name, + account_name="str", + failover_parameters={"failoverPolicies": [{"failoverPriority": 0, "id": "str", "locationName": "str"}]}, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.database_accounts.list( + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_by_resource_group(self, resource_group): + response = self.client.database_accounts.list_by_resource_group( + resource_group_name=resource_group.name, + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_keys(self, resource_group): + response = self.client.database_accounts.list_keys( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_connection_strings(self, resource_group): + response = self.client.database_accounts.list_connection_strings( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_offline_region(self, resource_group): + response = self.client.database_accounts.begin_offline_region( + resource_group_name=resource_group.name, + account_name="str", + region_parameter_for_offline={"region": "str"}, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_online_region(self, resource_group): + response = self.client.database_accounts.begin_online_region( + resource_group_name=resource_group.name, + account_name="str", + region_parameter_for_online={"region": "str"}, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_read_only_keys(self, resource_group): + response = self.client.database_accounts.get_read_only_keys( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_read_only_keys(self, resource_group): + response = self.client.database_accounts.list_read_only_keys( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_regenerate_key(self, resource_group): + response = self.client.database_accounts.begin_regenerate_key( + resource_group_name=resource_group.name, + account_name="str", + key_to_regenerate={"keyKind": "str"}, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_check_name_exists(self, resource_group): + response = self.client.database_accounts.check_name_exists( + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metrics(self, resource_group): + response = self.client.database_accounts.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + filter="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_usages(self, resource_group): + response = self.client.database_accounts.list_usages( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metric_definitions(self, resource_group): + response = self.client.database_accounts.list_metric_definitions( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_accounts_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_accounts_operations_async.py new file mode 100644 index 00000000000..db9872a38c7 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_accounts_operations_async.py @@ -0,0 +1,396 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementDatabaseAccountsOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get(self, resource_group): + response = await self.client.database_accounts.get( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_update(self, resource_group): + response = await ( + await self.client.database_accounts.begin_update( + resource_group_name=resource_group.name, + account_name="str", + update_parameters={ + "analyticalStorageConfiguration": {"schemaType": "str"}, + "apiProperties": {"serverVersion": "str"}, + "backupPolicy": "backup_policy", + "capabilities": [{"name": "str"}], + "capacity": {"totalThroughputLimit": 0}, + "connectorOffer": "str", + "consistencyPolicy": { + "defaultConsistencyLevel": "str", + "maxIntervalInSeconds": 0, + "maxStalenessPrefix": 0, + }, + "cors": [ + { + "allowedOrigins": "str", + "allowedHeaders": "str", + "allowedMethods": "str", + "exposedHeaders": "str", + "maxAgeInSeconds": 0, + } + ], + "customerManagedKeyStatus": "str", + "defaultIdentity": "str", + "disableKeyBasedMetadataWriteAccess": bool, + "disableLocalAuth": bool, + "enableAnalyticalStorage": bool, + 
"enableAutomaticFailover": bool, + "enableBurstCapacity": bool, + "enableCassandraConnector": bool, + "enableFreeTier": bool, + "enableMultipleWriteLocations": bool, + "enablePartitionMerge": bool, + "identity": { + "principalId": "str", + "tenantId": "str", + "type": "str", + "userAssignedIdentities": {"str": {"clientId": "str", "principalId": "str"}}, + }, + "ipRules": [{"ipAddressOrRange": "str"}], + "isVirtualNetworkFilterEnabled": bool, + "keyVaultKeyUri": "str", + "keysMetadata": { + "primaryMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + "primaryReadonlyMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + "secondaryMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + "secondaryReadonlyMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + }, + "location": "str", + "locations": [ + { + "documentEndpoint": "str", + "failoverPriority": 0, + "id": "str", + "isZoneRedundant": bool, + "locationName": "str", + "provisioningState": "str", + } + ], + "minimalTlsVersion": "str", + "networkAclBypass": "str", + "networkAclBypassResourceIds": ["str"], + "publicNetworkAccess": "str", + "tags": {"str": "str"}, + "virtualNetworkRules": [{"id": "str", "ignoreMissingVNetServiceEndpoint": bool}], + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_or_update(self, resource_group): + response = await ( + await self.client.database_accounts.begin_create_or_update( + resource_group_name=resource_group.name, + account_name="str", + create_update_parameters={ + "databaseAccountOfferType": "Standard", + "locations": [ + { + "documentEndpoint": "str", + "failoverPriority": 0, + "id": "str", + "isZoneRedundant": bool, + "locationName": "str", + "provisioningState": "str", + } + ], + "analyticalStorageConfiguration": {"schemaType": "str"}, + "apiProperties": {"serverVersion": "str"}, + "backupPolicy": "backup_policy", + "capabilities": [{"name": "str"}], + "capacity": {"totalThroughputLimit": 0}, + "connectorOffer": "str", + "consistencyPolicy": { + "defaultConsistencyLevel": "str", + "maxIntervalInSeconds": 0, + "maxStalenessPrefix": 0, + }, + "cors": [ + { + "allowedOrigins": "str", + "allowedHeaders": "str", + "allowedMethods": "str", + "exposedHeaders": "str", + "maxAgeInSeconds": 0, + } + ], + "createMode": "Default", + "customerManagedKeyStatus": "str", + "defaultIdentity": "str", + "disableKeyBasedMetadataWriteAccess": bool, + "disableLocalAuth": bool, + "enableAnalyticalStorage": bool, + "enableAutomaticFailover": bool, + "enableBurstCapacity": bool, + "enableCassandraConnector": bool, + "enableFreeTier": bool, + "enableMultipleWriteLocations": bool, + "enablePartitionMerge": bool, + "id": "str", + "identity": { + "principalId": "str", + "tenantId": "str", + "type": "str", + "userAssignedIdentities": {"str": {"clientId": "str", "principalId": "str"}}, + }, + "ipRules": [{"ipAddressOrRange": "str"}], + "isVirtualNetworkFilterEnabled": bool, + "keyVaultKeyUri": "str", + "keysMetadata": { + "primaryMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + "primaryReadonlyMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + "secondaryMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + 
"secondaryReadonlyMasterKey": {"generationTime": "2020-02-20 00:00:00"}, + }, + "kind": "str", + "location": "str", + "minimalTlsVersion": "str", + "name": "str", + "networkAclBypass": "str", + "networkAclBypassResourceIds": ["str"], + "publicNetworkAccess": "str", + "restoreParameters": { + "databasesToRestore": [{"collectionNames": ["str"], "databaseName": "str"}], + "gremlinDatabasesToRestore": [{"databaseName": "str", "graphNames": ["str"]}], + "restoreMode": "str", + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + "tablesToRestore": ["str"], + }, + "tags": {"str": "str"}, + "type": "str", + "virtualNetworkRules": [{"id": "str", "ignoreMissingVNetServiceEndpoint": bool}], + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete(self, resource_group): + response = await ( + await self.client.database_accounts.begin_delete( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_failover_priority_change(self, resource_group): + response = await ( + await self.client.database_accounts.begin_failover_priority_change( + resource_group_name=resource_group.name, + account_name="str", + failover_parameters={"failoverPolicies": [{"failoverPriority": 0, "id": "str", "locationName": "str"}]}, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.database_accounts.list( + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_by_resource_group(self, resource_group): + response = self.client.database_accounts.list_by_resource_group( + resource_group_name=resource_group.name, + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_keys(self, resource_group): + response = await self.client.database_accounts.list_keys( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_connection_strings(self, resource_group): + response = await self.client.database_accounts.list_connection_strings( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_offline_region(self, resource_group): + response = await ( + await self.client.database_accounts.begin_offline_region( + resource_group_name=resource_group.name, + account_name="str", + region_parameter_for_offline={"region": "str"}, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_online_region(self, resource_group): + response = await ( + await self.client.database_accounts.begin_online_region( + resource_group_name=resource_group.name, + account_name="str", + region_parameter_for_online={"region": "str"}, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_read_only_keys(self, resource_group): + response = await self.client.database_accounts.get_read_only_keys( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_read_only_keys(self, resource_group): + response = await self.client.database_accounts.list_read_only_keys( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_regenerate_key(self, resource_group): + response = await ( + await self.client.database_accounts.begin_regenerate_key( + resource_group_name=resource_group.name, + account_name="str", + key_to_regenerate={"keyKind": "str"}, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_check_name_exists(self, resource_group): + response = await self.client.database_accounts.check_name_exists( + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metrics(self, resource_group): + response = self.client.database_accounts.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + filter="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_usages(self, resource_group): + response = self.client.database_accounts.list_usages( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metric_definitions(self, resource_group): + response = self.client.database_accounts.list_metric_definitions( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... 
diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_operations.py new file mode 100644 index 00000000000..a653a14fc0b --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_operations.py @@ -0,0 +1,59 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementDatabaseOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metrics(self, resource_group): + response = self.client.database.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + filter="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_usages(self, resource_group): + response = self.client.database.list_usages( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metric_definitions(self, resource_group): + response = self.client.database.list_metric_definitions( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_operations_async.py new file mode 100644 index 00000000000..4827a58d3a6 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_database_operations_async.py @@ -0,0 +1,60 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementDatabaseOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metrics(self, resource_group): + response = self.client.database.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + filter="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_usages(self, resource_group): + response = self.client.database.list_usages( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metric_definitions(self, resource_group): + response = self.client.database.list_metric_definitions( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... 
diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_gremlin_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_gremlin_resources_operations.py new file mode 100644 index 00000000000..32146e56e17 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_gremlin_resources_operations.py @@ -0,0 +1,339 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementGremlinResourcesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_gremlin_databases(self, resource_group): + response = self.client.gremlin_resources.list_gremlin_databases( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_gremlin_database(self, resource_group): + response = self.client.gremlin_resources.get_gremlin_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_gremlin_database(self, resource_group): + response = self.client.gremlin_resources.begin_create_update_gremlin_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + create_update_gremlin_database_parameters={ + "resource": { + "id": "str", + "createMode": "Default", + "restoreParameters": { + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + }, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_gremlin_database(self, resource_group): + response = self.client.gremlin_resources.begin_delete_gremlin_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_gremlin_database_throughput(self, resource_group): + response = self.client.gremlin_resources.get_gremlin_database_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_update_gremlin_database_throughput(self, resource_group): + response = self.client.gremlin_resources.begin_update_gremlin_database_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_gremlin_database_to_autoscale(self, resource_group): + response = self.client.gremlin_resources.begin_migrate_gremlin_database_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_gremlin_database_to_manual_throughput(self, resource_group): + response = self.client.gremlin_resources.begin_migrate_gremlin_database_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_gremlin_graphs(self, resource_group): + response = self.client.gremlin_resources.list_gremlin_graphs( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_gremlin_graph(self, resource_group): + response = self.client.gremlin_resources.get_gremlin_graph( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_gremlin_graph(self, resource_group): + response = self.client.gremlin_resources.begin_create_update_gremlin_graph( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + create_update_gremlin_graph_parameters={ + "resource": { + "id": "str", + "analyticalStorageTtl": 0, + "conflictResolutionPolicy": { + "conflictResolutionPath": "str", + "conflictResolutionProcedure": "str", + "mode": "LastWriterWins", + }, + "createMode": "Default", + "defaultTtl": 0, + "indexingPolicy": { + "automatic": bool, + "compositeIndexes": [[{"order": "str", "path": "str"}]], + "excludedPaths": [{"path": "str"}], + "includedPaths": [ + {"indexes": [{"dataType": "String", "kind": "Hash", "precision": 0}], "path": "str"} + ], + "indexingMode": "consistent", + "spatialIndexes": [{"path": "str", "types": ["str"]}], + }, + "partitionKey": {"kind": "Hash", "paths": ["str"], "systemKey": bool, "version": 0}, + "restoreParameters": { + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + }, + "uniqueKeyPolicy": {"uniqueKeys": [{"paths": ["str"]}]}, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_gremlin_graph(self, resource_group): + response = self.client.gremlin_resources.begin_delete_gremlin_graph( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_gremlin_graph_throughput(self, resource_group): + response = self.client.gremlin_resources.get_gremlin_graph_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_update_gremlin_graph_throughput(self, resource_group): + response = self.client.gremlin_resources.begin_update_gremlin_graph_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_gremlin_graph_to_autoscale(self, resource_group): + response = self.client.gremlin_resources.begin_migrate_gremlin_graph_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_gremlin_graph_to_manual_throughput(self, resource_group): + response = self.client.gremlin_resources.begin_migrate_gremlin_graph_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_retrieve_continuous_backup_information(self, resource_group): + response = self.client.gremlin_resources.begin_retrieve_continuous_backup_information( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + location={"location": "str"}, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_gremlin_resources_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_gremlin_resources_operations_async.py new file mode 100644 index 00000000000..4015eb3dfe5 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_gremlin_resources_operations_async.py @@ -0,0 +1,362 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementGremlinResourcesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_gremlin_databases(self, resource_group): + response = self.client.gremlin_resources.list_gremlin_databases( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_gremlin_database(self, resource_group): + response = await self.client.gremlin_resources.get_gremlin_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_gremlin_database(self, resource_group): + response = await ( + await self.client.gremlin_resources.begin_create_update_gremlin_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + create_update_gremlin_database_parameters={ + "resource": { + "id": "str", + "createMode": "Default", + "restoreParameters": { + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + }, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_gremlin_database(self, resource_group): + response = await ( + await self.client.gremlin_resources.begin_delete_gremlin_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_gremlin_database_throughput(self, resource_group): + response = await self.client.gremlin_resources.get_gremlin_database_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_update_gremlin_database_throughput(self, resource_group): + response = await ( + await self.client.gremlin_resources.begin_update_gremlin_database_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_gremlin_database_to_autoscale(self, resource_group): + response = await ( + await self.client.gremlin_resources.begin_migrate_gremlin_database_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_gremlin_database_to_manual_throughput(self, resource_group): + response = await ( + await self.client.gremlin_resources.begin_migrate_gremlin_database_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_gremlin_graphs(self, resource_group): + response = self.client.gremlin_resources.list_gremlin_graphs( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_gremlin_graph(self, resource_group): + response = await self.client.gremlin_resources.get_gremlin_graph( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_gremlin_graph(self, resource_group): + response = await ( + await self.client.gremlin_resources.begin_create_update_gremlin_graph( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + create_update_gremlin_graph_parameters={ + "resource": { + "id": "str", + "analyticalStorageTtl": 0, + "conflictResolutionPolicy": { + "conflictResolutionPath": "str", + "conflictResolutionProcedure": "str", + "mode": "LastWriterWins", + }, + "createMode": "Default", + "defaultTtl": 0, + "indexingPolicy": { + "automatic": bool, + "compositeIndexes": [[{"order": "str", "path": "str"}]], + "excludedPaths": [{"path": "str"}], + "includedPaths": [ + {"indexes": [{"dataType": "String", "kind": "Hash", "precision": 0}], "path": "str"} + ], + "indexingMode": "consistent", + "spatialIndexes": [{"path": "str", "types": ["str"]}], + }, + "partitionKey": {"kind": "Hash", "paths": ["str"], "systemKey": bool, "version": 0}, + "restoreParameters": { + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + }, + "uniqueKeyPolicy": {"uniqueKeys": [{"paths": ["str"]}]}, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_gremlin_graph(self, resource_group): + response = await ( + await self.client.gremlin_resources.begin_delete_gremlin_graph( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_gremlin_graph_throughput(self, resource_group): + response = await self.client.gremlin_resources.get_gremlin_graph_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_update_gremlin_graph_throughput(self, resource_group): + response = await ( + await self.client.gremlin_resources.begin_update_gremlin_graph_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_gremlin_graph_to_autoscale(self, resource_group): + response = await ( + await self.client.gremlin_resources.begin_migrate_gremlin_graph_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_gremlin_graph_to_manual_throughput(self, resource_group): + response = await ( + await self.client.gremlin_resources.begin_migrate_gremlin_graph_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_retrieve_continuous_backup_information(self, resource_group): + response = await ( + await self.client.gremlin_resources.begin_retrieve_continuous_backup_information( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + graph_name="str", + location={"location": "str"}, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_locations_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_locations_operations.py new file mode 100644 index 00000000000..553979647f9 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_locations_operations.py @@ -0,0 +1,40 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementLocationsOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.locations.list( + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get(self, resource_group): + response = self.client.locations.get( + location="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_locations_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_locations_operations_async.py new file mode 100644 index 00000000000..07fe77be07f --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_locations_operations_async.py @@ -0,0 +1,41 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementLocationsOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.locations.list( + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get(self, resource_group): + response = await self.client.locations.get( + location="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_mongo_db_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_mongo_db_resources_operations.py new file mode 100644 index 00000000000..1a1ce74f949 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_mongo_db_resources_operations.py @@ -0,0 +1,440 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementMongoDBResourcesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_mongo_db_databases(self, resource_group): + response = self.client.mongo_db_resources.list_mongo_db_databases( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_mongo_db_database(self, resource_group): + response = self.client.mongo_db_resources.get_mongo_db_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_mongo_db_database(self, resource_group): + response = self.client.mongo_db_resources.begin_create_update_mongo_db_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + create_update_mongo_db_database_parameters={ + "resource": { + "id": "str", + "createMode": "Default", + "restoreParameters": { + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + }, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_mongo_db_database(self, resource_group): + response = self.client.mongo_db_resources.begin_delete_mongo_db_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_mongo_db_database_throughput(self, resource_group): + response = self.client.mongo_db_resources.get_mongo_db_database_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_update_mongo_db_database_throughput(self, resource_group): + response = self.client.mongo_db_resources.begin_update_mongo_db_database_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_mongo_db_database_to_autoscale(self, resource_group): + response = self.client.mongo_db_resources.begin_migrate_mongo_db_database_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_mongo_db_database_to_manual_throughput(self, resource_group): + response = self.client.mongo_db_resources.begin_migrate_mongo_db_database_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_mongo_db_collections(self, resource_group): + response = self.client.mongo_db_resources.list_mongo_db_collections( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_mongo_db_collection(self, resource_group): + response = self.client.mongo_db_resources.get_mongo_db_collection( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_mongo_db_collection(self, resource_group): + response = self.client.mongo_db_resources.begin_create_update_mongo_db_collection( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + create_update_mongo_db_collection_parameters={ + "resource": { + "id": "str", + "analyticalStorageTtl": 0, + "createMode": "Default", + "indexes": [{"key": {"keys": ["str"]}, "options": {"expireAfterSeconds": 0, "unique": bool}}], + "restoreParameters": { + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + }, + "shardKey": {"str": "str"}, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_mongo_db_collection(self, resource_group): + response = self.client.mongo_db_resources.begin_delete_mongo_db_collection( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_mongo_db_collection_throughput(self, resource_group): + response = self.client.mongo_db_resources.get_mongo_db_collection_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_update_mongo_db_collection_throughput(self, resource_group): + response = self.client.mongo_db_resources.begin_update_mongo_db_collection_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_mongo_db_collection_to_autoscale(self, resource_group): + response = self.client.mongo_db_resources.begin_migrate_mongo_db_collection_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_mongo_db_collection_to_manual_throughput(self, resource_group): + response = self.client.mongo_db_resources.begin_migrate_mongo_db_collection_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_mongo_role_definition(self, resource_group): + response = self.client.mongo_db_resources.get_mongo_role_definition( + mongo_role_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_mongo_role_definition(self, resource_group): + response = self.client.mongo_db_resources.begin_create_update_mongo_role_definition( + mongo_role_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + create_update_mongo_role_definition_parameters={ + "databaseName": "str", + "privileges": [{"actions": ["str"], "resource": {"collection": "str", "db": "str"}}], + "roleName": "str", + "roles": [{"db": "str", "role": "str"}], + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_mongo_role_definition(self, resource_group): + response = self.client.mongo_db_resources.begin_delete_mongo_role_definition( + mongo_role_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_mongo_role_definitions(self, resource_group): + response = self.client.mongo_db_resources.list_mongo_role_definitions( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_mongo_user_definition(self, resource_group): + response = self.client.mongo_db_resources.get_mongo_user_definition( + mongo_user_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_mongo_user_definition(self, resource_group): + response = self.client.mongo_db_resources.begin_create_update_mongo_user_definition( + mongo_user_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + create_update_mongo_user_definition_parameters={ + "customData": "str", + "databaseName": "str", + "mechanisms": "str", + "password": "str", + "roles": [{"db": "str", "role": "str"}], + "userName": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_mongo_user_definition(self, resource_group): + response = self.client.mongo_db_resources.begin_delete_mongo_user_definition( + mongo_user_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_mongo_user_definitions(self, resource_group): + response = self.client.mongo_db_resources.list_mongo_user_definitions( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_retrieve_continuous_backup_information(self, resource_group): + response = self.client.mongo_db_resources.begin_retrieve_continuous_backup_information( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + location={"location": "str"}, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_mongo_db_resources_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_mongo_db_resources_operations_async.py new file mode 100644 index 00000000000..ecc8700fe85 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_mongo_db_resources_operations_async.py @@ -0,0 +1,471 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementMongoDBResourcesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_mongo_db_databases(self, resource_group): + response = self.client.mongo_db_resources.list_mongo_db_databases( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_mongo_db_database(self, resource_group): + response = await self.client.mongo_db_resources.get_mongo_db_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_mongo_db_database(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_create_update_mongo_db_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + create_update_mongo_db_database_parameters={ + "resource": { + "id": "str", + "createMode": "Default", + "restoreParameters": { + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + }, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_mongo_db_database(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_delete_mongo_db_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_mongo_db_database_throughput(self, resource_group): + response = await self.client.mongo_db_resources.get_mongo_db_database_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_update_mongo_db_database_throughput(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_update_mongo_db_database_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_mongo_db_database_to_autoscale(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_migrate_mongo_db_database_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_mongo_db_database_to_manual_throughput(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_migrate_mongo_db_database_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_mongo_db_collections(self, resource_group): + response = self.client.mongo_db_resources.list_mongo_db_collections( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_mongo_db_collection(self, resource_group): + response = await self.client.mongo_db_resources.get_mongo_db_collection( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_mongo_db_collection(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_create_update_mongo_db_collection( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + create_update_mongo_db_collection_parameters={ + "resource": { + "id": "str", + "analyticalStorageTtl": 0, + "createMode": "Default", + "indexes": [{"key": {"keys": ["str"]}, "options": {"expireAfterSeconds": 0, "unique": bool}}], + "restoreParameters": { + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + }, + "shardKey": {"str": "str"}, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_mongo_db_collection(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_delete_mongo_db_collection( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_mongo_db_collection_throughput(self, resource_group): + response = await self.client.mongo_db_resources.get_mongo_db_collection_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_update_mongo_db_collection_throughput(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_update_mongo_db_collection_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_mongo_db_collection_to_autoscale(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_migrate_mongo_db_collection_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_mongo_db_collection_to_manual_throughput(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_migrate_mongo_db_collection_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_mongo_role_definition(self, resource_group): + response = await self.client.mongo_db_resources.get_mongo_role_definition( + mongo_role_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_mongo_role_definition(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_create_update_mongo_role_definition( + mongo_role_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + create_update_mongo_role_definition_parameters={ + "databaseName": "str", + "privileges": [{"actions": ["str"], "resource": {"collection": "str", "db": "str"}}], + "roleName": "str", + "roles": [{"db": "str", "role": "str"}], + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_mongo_role_definition(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_delete_mongo_role_definition( + mongo_role_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_mongo_role_definitions(self, resource_group): + response = self.client.mongo_db_resources.list_mongo_role_definitions( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_mongo_user_definition(self, resource_group): + response = await self.client.mongo_db_resources.get_mongo_user_definition( + mongo_user_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_mongo_user_definition(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_create_update_mongo_user_definition( + mongo_user_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + create_update_mongo_user_definition_parameters={ + "customData": "str", + "databaseName": "str", + "mechanisms": "str", + "password": "str", + "roles": [{"db": "str", "role": "str"}], + "userName": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_mongo_user_definition(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_delete_mongo_user_definition( + mongo_user_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_mongo_user_definitions(self, resource_group): + response = self.client.mongo_db_resources.list_mongo_user_definitions( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_retrieve_continuous_backup_information(self, resource_group): + response = await ( + await self.client.mongo_db_resources.begin_retrieve_continuous_backup_information( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + collection_name="str", + location={"location": "str"}, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_notebook_workspaces_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_notebook_workspaces_operations.py new file mode 100644 index 00000000000..0454f2e5a93 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_notebook_workspaces_operations.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementNotebookWorkspacesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_by_database_account(self, resource_group): + response = self.client.notebook_workspaces.list_by_database_account( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get(self, resource_group): + response = self.client.notebook_workspaces.get( + resource_group_name=resource_group.name, + account_name="str", + notebook_workspace_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_or_update(self, resource_group): + response = self.client.notebook_workspaces.begin_create_or_update( + resource_group_name=resource_group.name, + account_name="str", + notebook_workspace_name="str", + notebook_create_update_parameters={"id": "str", "name": "str", "type": "str"}, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete(self, resource_group): + response = self.client.notebook_workspaces.begin_delete( + resource_group_name=resource_group.name, + account_name="str", + notebook_workspace_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_connection_info(self, resource_group): + response = self.client.notebook_workspaces.list_connection_info( + resource_group_name=resource_group.name, + account_name="str", + notebook_workspace_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_regenerate_auth_token(self, resource_group): + response = self.client.notebook_workspaces.begin_regenerate_auth_token( + resource_group_name=resource_group.name, + account_name="str", + notebook_workspace_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_start(self, resource_group): + response = self.client.notebook_workspaces.begin_start( + resource_group_name=resource_group.name, + account_name="str", + notebook_workspace_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_notebook_workspaces_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_notebook_workspaces_operations_async.py new file mode 100644 index 00000000000..d6d87b993fc --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_notebook_workspaces_operations_async.py @@ -0,0 +1,119 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementNotebookWorkspacesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_by_database_account(self, resource_group): + response = self.client.notebook_workspaces.list_by_database_account( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get(self, resource_group): + response = await self.client.notebook_workspaces.get( + resource_group_name=resource_group.name, + account_name="str", + notebook_workspace_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_or_update(self, resource_group): + response = await ( + await self.client.notebook_workspaces.begin_create_or_update( + resource_group_name=resource_group.name, + account_name="str", + notebook_workspace_name="str", + notebook_create_update_parameters={"id": "str", "name": "str", "type": "str"}, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete(self, resource_group): + response = await ( + await self.client.notebook_workspaces.begin_delete( + resource_group_name=resource_group.name, + account_name="str", + notebook_workspace_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_connection_info(self, resource_group): + response = await self.client.notebook_workspaces.list_connection_info( + resource_group_name=resource_group.name, + account_name="str", + notebook_workspace_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_regenerate_auth_token(self, resource_group): + response = await ( + await self.client.notebook_workspaces.begin_regenerate_auth_token( + resource_group_name=resource_group.name, + account_name="str", + notebook_workspace_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_start(self, resource_group): + response = await ( + await self.client.notebook_workspaces.begin_start( + resource_group_name=resource_group.name, + account_name="str", + notebook_workspace_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_operations.py new file mode 100644 index 00000000000..4097831969d --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_operations.py @@ -0,0 +1,29 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.operations.list( + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_operations_async.py new file mode 100644 index 00000000000..a90f811a5df --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_operations_async.py @@ -0,0 +1,30 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.operations.list( + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_operations.py new file mode 100644 index 00000000000..02fb73a2bc5 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_operations.py @@ -0,0 +1,35 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementPartitionKeyRangeIdOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metrics(self, resource_group): + response = self.client.partition_key_range_id.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + collection_rid="str", + partition_key_range_id="str", + filter="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_operations_async.py new file mode 100644 index 00000000000..690749c6db6 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_operations_async.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementPartitionKeyRangeIdOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metrics(self, resource_group): + response = self.client.partition_key_range_id.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + database_rid="str", + collection_rid="str", + partition_key_range_id="str", + filter="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_region_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_region_operations.py new file mode 100644 index 00000000000..e527d1ad28b --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_region_operations.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementPartitionKeyRangeIdRegionOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metrics(self, resource_group): + response = self.client.partition_key_range_id_region.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + region="str", + database_rid="str", + collection_rid="str", + partition_key_range_id="str", + filter="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_region_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_region_operations_async.py new file mode 100644 index 00000000000..aecd53e18f6 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_partition_key_range_id_region_operations_async.py @@ -0,0 +1,37 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementPartitionKeyRangeIdRegionOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metrics(self, resource_group): + response = self.client.partition_key_range_id_region.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + region="str", + database_rid="str", + collection_rid="str", + partition_key_range_id="str", + filter="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_operations.py new file mode 100644 index 00000000000..1ea1133871a --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_operations.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementPercentileOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metrics(self, resource_group): + response = self.client.percentile.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + filter="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_operations_async.py new file mode 100644 index 00000000000..660bed633ac --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_operations_async.py @@ -0,0 +1,33 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementPercentileOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metrics(self, resource_group): + response = self.client.percentile.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + filter="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_source_target_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_source_target_operations.py new file mode 100644 index 00000000000..2820b4a6c83 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_source_target_operations.py @@ -0,0 +1,34 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementPercentileSourceTargetOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metrics(self, resource_group): + response = self.client.percentile_source_target.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + source_region="str", + target_region="str", + filter="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_source_target_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_source_target_operations_async.py new file mode 100644 index 00000000000..f9a0ee9964c --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_source_target_operations_async.py @@ -0,0 +1,35 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementPercentileSourceTargetOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metrics(self, resource_group): + response = self.client.percentile_source_target.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + source_region="str", + target_region="str", + filter="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_target_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_target_operations.py new file mode 100644 index 00000000000..3562dadeaee --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_target_operations.py @@ -0,0 +1,33 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementPercentileTargetOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_metrics(self, resource_group): + response = self.client.percentile_target.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + target_region="str", + filter="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_target_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_target_operations_async.py new file mode 100644 index 00000000000..91e96ca2ee4 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_percentile_target_operations_async.py @@ -0,0 +1,34 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementPercentileTargetOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_metrics(self, resource_group): + response = self.client.percentile_target.list_metrics( + resource_group_name=resource_group.name, + account_name="str", + target_region="str", + filter="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_endpoint_connections_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_endpoint_connections_operations.py new file mode 100644 index 00000000000..b25794ad024 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_endpoint_connections_operations.py @@ -0,0 +1,79 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementPrivateEndpointConnectionsOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_by_database_account(self, resource_group): + response = self.client.private_endpoint_connections.list_by_database_account( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get(self, resource_group): + response = self.client.private_endpoint_connections.get( + resource_group_name=resource_group.name, + account_name="str", + private_endpoint_connection_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_or_update(self, resource_group): + response = self.client.private_endpoint_connections.begin_create_or_update( + resource_group_name=resource_group.name, + account_name="str", + private_endpoint_connection_name="str", + parameters={ + "groupId": "str", + "id": "str", + "name": "str", + "privateEndpoint": {"id": "str"}, + "privateLinkServiceConnectionState": {"actionsRequired": "str", "description": "str", "status": "str"}, + "provisioningState": "str", + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete(self, resource_group): + response = self.client.private_endpoint_connections.begin_delete( + resource_group_name=resource_group.name, + account_name="str", + private_endpoint_connection_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_endpoint_connections_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_endpoint_connections_operations_async.py new file mode 100644 index 00000000000..87366a60f7b --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_endpoint_connections_operations_async.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. 
+# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementPrivateEndpointConnectionsOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_by_database_account(self, resource_group): + response = self.client.private_endpoint_connections.list_by_database_account( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get(self, resource_group): + response = await self.client.private_endpoint_connections.get( + resource_group_name=resource_group.name, + account_name="str", + private_endpoint_connection_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_or_update(self, resource_group): + response = await ( + await self.client.private_endpoint_connections.begin_create_or_update( + resource_group_name=resource_group.name, + account_name="str", + private_endpoint_connection_name="str", + parameters={ + "groupId": "str", + "id": "str", + "name": "str", + "privateEndpoint": {"id": "str"}, + "privateLinkServiceConnectionState": { + "actionsRequired": "str", + "description": "str", + "status": "str", + }, + "provisioningState": "str", + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete(self, resource_group): + response = await ( + await self.client.private_endpoint_connections.begin_delete( + resource_group_name=resource_group.name, + account_name="str", + private_endpoint_connection_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_link_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_link_resources_operations.py new file mode 100644 index 00000000000..e7cdf9f4793 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_link_resources_operations.py @@ -0,0 +1,44 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementPrivateLinkResourcesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_by_database_account(self, resource_group): + response = self.client.private_link_resources.list_by_database_account( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get(self, resource_group): + response = self.client.private_link_resources.get( + resource_group_name=resource_group.name, + account_name="str", + group_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_link_resources_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_link_resources_operations_async.py new file mode 100644 index 00000000000..93b9a0864e3 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_private_link_resources_operations_async.py @@ -0,0 +1,45 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementPrivateLinkResourcesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_by_database_account(self, resource_group): + response = self.client.private_link_resources.list_by_database_account( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get(self, resource_group): + response = await self.client.private_link_resources.get( + resource_group_name=resource_group.name, + account_name="str", + group_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_database_accounts_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_database_accounts_operations.py new file mode 100644 index 00000000000..d14381d3655 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_database_accounts_operations.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableDatabaseAccountsOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_by_location(self, resource_group): + response = self.client.restorable_database_accounts.list_by_location( + location="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.restorable_database_accounts.list( + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_by_location(self, resource_group): + response = self.client.restorable_database_accounts.get_by_location( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_database_accounts_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_database_accounts_operations_async.py new file mode 100644 index 00000000000..c4b673a42dd --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_database_accounts_operations_async.py @@ -0,0 +1,53 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableDatabaseAccountsOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_by_location(self, resource_group): + response = self.client.restorable_database_accounts.list_by_location( + location="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.restorable_database_accounts.list( + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_by_location(self, resource_group): + response = await self.client.restorable_database_accounts.get_by_location( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_databases_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_databases_operations.py new file mode 100644 index 00000000000..7e3f70e5da3 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_databases_operations.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableGremlinDatabasesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.restorable_gremlin_databases.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_databases_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_databases_operations_async.py new file mode 100644 index 00000000000..2906af86599 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_databases_operations_async.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableGremlinDatabasesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.restorable_gremlin_databases.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_graphs_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_graphs_operations.py new file mode 100644 index 00000000000..8c68b09df59 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_graphs_operations.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableGremlinGraphsOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.restorable_gremlin_graphs.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_graphs_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_graphs_operations_async.py new file mode 100644 index 00000000000..8635bee2e18 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_graphs_operations_async.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableGremlinGraphsOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.restorable_gremlin_graphs.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_resources_operations.py new file mode 100644 index 00000000000..2442d0ec6f4 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_resources_operations.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableGremlinResourcesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.restorable_gremlin_resources.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_resources_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_resources_operations_async.py new file mode 100644 index 00000000000..9377c87873f --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_gremlin_resources_operations_async.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableGremlinResourcesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.restorable_gremlin_resources.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_collections_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_collections_operations.py new file mode 100644 index 00000000000..8e73eacd5e6 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_collections_operations.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableMongodbCollectionsOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.restorable_mongodb_collections.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_collections_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_collections_operations_async.py new file mode 100644 index 00000000000..99d5e223928 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_collections_operations_async.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableMongodbCollectionsOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.restorable_mongodb_collections.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_databases_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_databases_operations.py new file mode 100644 index 00000000000..56cedccada2 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_databases_operations.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableMongodbDatabasesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.restorable_mongodb_databases.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_databases_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_databases_operations_async.py new file mode 100644 index 00000000000..ca71ba14853 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_databases_operations_async.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableMongodbDatabasesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.restorable_mongodb_databases.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_resources_operations.py new file mode 100644 index 00000000000..85a12dd3a36 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_resources_operations.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableMongodbResourcesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.restorable_mongodb_resources.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_resources_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_resources_operations_async.py new file mode 100644 index 00000000000..6c3576ca155 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_mongodb_resources_operations_async.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableMongodbResourcesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.restorable_mongodb_resources.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_containers_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_containers_operations.py new file mode 100644 index 00000000000..f265cc9238f --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_containers_operations.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableSqlContainersOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.restorable_sql_containers.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_containers_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_containers_operations_async.py new file mode 100644 index 00000000000..3868e02a4cc --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_containers_operations_async.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableSqlContainersOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.restorable_sql_containers.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_databases_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_databases_operations.py new file mode 100644 index 00000000000..4708d00a393 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_databases_operations.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableSqlDatabasesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.restorable_sql_databases.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_databases_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_databases_operations_async.py new file mode 100644 index 00000000000..4f959dd3ff5 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_databases_operations_async.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableSqlDatabasesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.restorable_sql_databases.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_resources_operations.py new file mode 100644 index 00000000000..e08b885ba42 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_resources_operations.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableSqlResourcesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.restorable_sql_resources.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_resources_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_resources_operations_async.py new file mode 100644 index 00000000000..a156fc2237a --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_sql_resources_operations_async.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableSqlResourcesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.restorable_sql_resources.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_table_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_table_resources_operations.py new file mode 100644 index 00000000000..f02853bfe1d --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_table_resources_operations.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableTableResourcesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.restorable_table_resources.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_table_resources_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_table_resources_operations_async.py new file mode 100644 index 00000000000..c8ef6faa9ef --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_table_resources_operations_async.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableTableResourcesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.restorable_table_resources.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_tables_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_tables_operations.py new file mode 100644 index 00000000000..b0b74c9db98 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_tables_operations.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableTablesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.restorable_tables.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_tables_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_tables_operations_async.py new file mode 100644 index 00000000000..52f96bb058a --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_restorable_tables_operations_async.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementRestorableTablesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.restorable_tables.list( + location="str", + instance_id="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_service_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_service_operations.py new file mode 100644 index 00000000000..de14b67e43f --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_service_operations.py @@ -0,0 +1,71 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementServiceOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.service.list( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create(self, resource_group): + response = self.client.service.begin_create( + resource_group_name=resource_group.name, + account_name="str", + service_name="str", + create_update_parameters={"properties": "service_resource_create_update_properties"}, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get(self, resource_group): + response = self.client.service.get( + resource_group_name=resource_group.name, + account_name="str", + service_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete(self, resource_group): + response = self.client.service.begin_delete( + resource_group_name=resource_group.name, + account_name="str", + service_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_service_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_service_operations_async.py new file mode 100644 index 00000000000..9e68d98f40c --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_service_operations_async.py @@ -0,0 +1,76 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementServiceOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.service.list( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create(self, resource_group): + response = await ( + await self.client.service.begin_create( + resource_group_name=resource_group.name, + account_name="str", + service_name="str", + create_update_parameters={"properties": "service_resource_create_update_properties"}, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get(self, resource_group): + response = await self.client.service.get( + resource_group_name=resource_group.name, + account_name="str", + service_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete(self, resource_group): + response = await ( + await self.client.service.begin_delete( + resource_group_name=resource_group.name, + account_name="str", + service_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_sql_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_sql_resources_operations.py new file mode 100644 index 00000000000..69b4962f52a --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_sql_resources_operations.py @@ -0,0 +1,717 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementSqlResourcesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_sql_databases(self, resource_group): + response = self.client.sql_resources.list_sql_databases( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_sql_database(self, resource_group): + response = self.client.sql_resources.get_sql_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_sql_database(self, resource_group): + response = self.client.sql_resources.begin_create_update_sql_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + create_update_sql_database_parameters={ + "resource": { + "id": "str", + "createMode": "Default", + "restoreParameters": { + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + }, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_sql_database(self, resource_group): + response = self.client.sql_resources.begin_delete_sql_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_sql_database_throughput(self, resource_group): + response = self.client.sql_resources.get_sql_database_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_update_sql_database_throughput(self, resource_group): + response = self.client.sql_resources.begin_update_sql_database_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_sql_database_to_autoscale(self, resource_group): + response = self.client.sql_resources.begin_migrate_sql_database_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_sql_database_to_manual_throughput(self, resource_group): + response = self.client.sql_resources.begin_migrate_sql_database_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_sql_containers(self, resource_group): + response = self.client.sql_resources.list_sql_containers( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_sql_container(self, resource_group): + response = self.client.sql_resources.get_sql_container( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_sql_container(self, resource_group): + response = self.client.sql_resources.begin_create_update_sql_container( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + create_update_sql_container_parameters={ + "resource": { + "id": "str", + "analyticalStorageTtl": 0, + "clientEncryptionPolicy": { + "includedPaths": [ + { + "clientEncryptionKeyId": "str", + "encryptionAlgorithm": "str", + "encryptionType": "str", + "path": "str", + } + ], + "policyFormatVersion": 0, + }, + "computedProperties": [{"name": "str", "query": "str"}], + "conflictResolutionPolicy": { + "conflictResolutionPath": "str", + "conflictResolutionProcedure": "str", + "mode": "LastWriterWins", + }, + "createMode": "Default", + "defaultTtl": 0, + "indexingPolicy": { + "automatic": bool, + "compositeIndexes": [[{"order": "str", "path": "str"}]], + "excludedPaths": [{"path": "str"}], + "includedPaths": [ + {"indexes": [{"dataType": "String", "kind": "Hash", "precision": 0}], "path": "str"} + ], + "indexingMode": "consistent", + "spatialIndexes": 
[{"path": "str", "types": ["str"]}], + }, + "partitionKey": {"kind": "Hash", "paths": ["str"], "systemKey": bool, "version": 0}, + "restoreParameters": { + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + }, + "uniqueKeyPolicy": {"uniqueKeys": [{"paths": ["str"]}]}, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_sql_container(self, resource_group): + response = self.client.sql_resources.begin_delete_sql_container( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_sql_container_throughput(self, resource_group): + response = self.client.sql_resources.get_sql_container_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_update_sql_container_throughput(self, resource_group): + response = self.client.sql_resources.begin_update_sql_container_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_sql_container_to_autoscale(self, resource_group): + response = self.client.sql_resources.begin_migrate_sql_container_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_sql_container_to_manual_throughput(self, resource_group): + response = self.client.sql_resources.begin_migrate_sql_container_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_client_encryption_keys(self, resource_group): + response = self.client.sql_resources.list_client_encryption_keys( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_client_encryption_key(self, resource_group): + response = self.client.sql_resources.get_client_encryption_key( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + client_encryption_key_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_client_encryption_key(self, resource_group): + response = self.client.sql_resources.begin_create_update_client_encryption_key( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + client_encryption_key_name="str", + create_update_client_encryption_key_parameters={ + "resource": { + "encryptionAlgorithm": "str", + "id": "str", + "keyWrapMetadata": {"algorithm": "str", "name": "str", "type": "str", "value": "str"}, + "wrappedDataEncryptionKey": bytes("bytes", encoding="utf-8"), + } + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_sql_stored_procedures(self, resource_group): + response = self.client.sql_resources.list_sql_stored_procedures( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_sql_stored_procedure(self, resource_group): + response = self.client.sql_resources.get_sql_stored_procedure( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + stored_procedure_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_sql_stored_procedure(self, resource_group): + response = self.client.sql_resources.begin_create_update_sql_stored_procedure( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + stored_procedure_name="str", + create_update_sql_stored_procedure_parameters={ + "resource": {"id": "str", "body": "str"}, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_sql_stored_procedure(self, resource_group): + response = self.client.sql_resources.begin_delete_sql_stored_procedure( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + stored_procedure_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_sql_user_defined_functions(self, resource_group): + response = self.client.sql_resources.list_sql_user_defined_functions( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_sql_user_defined_function(self, resource_group): + response = self.client.sql_resources.get_sql_user_defined_function( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + user_defined_function_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_sql_user_defined_function(self, resource_group): + response = self.client.sql_resources.begin_create_update_sql_user_defined_function( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + user_defined_function_name="str", + create_update_sql_user_defined_function_parameters={ + "resource": {"id": "str", "body": "str"}, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_sql_user_defined_function(self, resource_group): + response = self.client.sql_resources.begin_delete_sql_user_defined_function( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + user_defined_function_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_sql_triggers(self, resource_group): + response = self.client.sql_resources.list_sql_triggers( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_sql_trigger(self, resource_group): + response = self.client.sql_resources.get_sql_trigger( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + trigger_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_sql_trigger(self, resource_group): + response = self.client.sql_resources.begin_create_update_sql_trigger( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + trigger_name="str", + create_update_sql_trigger_parameters={ + "resource": {"id": "str", "body": "str", "triggerOperation": "str", "triggerType": "str"}, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_sql_trigger(self, resource_group): + response = self.client.sql_resources.begin_delete_sql_trigger( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + trigger_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_sql_role_definition(self, resource_group): + response = self.client.sql_resources.get_sql_role_definition( + role_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_sql_role_definition(self, resource_group): + response = self.client.sql_resources.begin_create_update_sql_role_definition( + role_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + create_update_sql_role_definition_parameters={ + "assignableScopes": ["str"], + "permissions": [{"dataActions": ["str"], "notDataActions": ["str"]}], + "roleName": "str", + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_sql_role_definition(self, resource_group): + response = self.client.sql_resources.begin_delete_sql_role_definition( + role_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_sql_role_definitions(self, resource_group): + response = self.client.sql_resources.list_sql_role_definitions( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_sql_role_assignment(self, resource_group): + response = self.client.sql_resources.get_sql_role_assignment( + role_assignment_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_sql_role_assignment(self, resource_group): + response = self.client.sql_resources.begin_create_update_sql_role_assignment( + role_assignment_id="str", + resource_group_name=resource_group.name, + account_name="str", + create_update_sql_role_assignment_parameters={ + "principalId": "str", + "roleDefinitionId": "str", + "scope": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_sql_role_assignment(self, resource_group): + response = self.client.sql_resources.begin_delete_sql_role_assignment( + role_assignment_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_sql_role_assignments(self, resource_group): + response = self.client.sql_resources.list_sql_role_assignments( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_retrieve_continuous_backup_information(self, resource_group): + response = self.client.sql_resources.begin_retrieve_continuous_backup_information( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + location={"location": "str"}, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_sql_resources_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_sql_resources_operations_async.py new file mode 100644 index 00000000000..179c071b2b3 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_sql_resources_operations_async.py @@ -0,0 +1,762 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. 
All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementSqlResourcesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_sql_databases(self, resource_group): + response = self.client.sql_resources.list_sql_databases( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_sql_database(self, resource_group): + response = await self.client.sql_resources.get_sql_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_sql_database(self, resource_group): + response = await ( + await self.client.sql_resources.begin_create_update_sql_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + create_update_sql_database_parameters={ + "resource": { + "id": "str", + "createMode": "Default", + "restoreParameters": { + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + }, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_sql_database(self, resource_group): + response = await ( + await self.client.sql_resources.begin_delete_sql_database( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_sql_database_throughput(self, resource_group): + response = await self.client.sql_resources.get_sql_database_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_update_sql_database_throughput(self, resource_group): + response = await ( + await self.client.sql_resources.begin_update_sql_database_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_sql_database_to_autoscale(self, resource_group): + response = await ( + await self.client.sql_resources.begin_migrate_sql_database_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_sql_database_to_manual_throughput(self, resource_group): + response = await ( + await self.client.sql_resources.begin_migrate_sql_database_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_sql_containers(self, resource_group): + response = self.client.sql_resources.list_sql_containers( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_sql_container(self, resource_group): + response = await self.client.sql_resources.get_sql_container( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_sql_container(self, resource_group): + response = await ( + await self.client.sql_resources.begin_create_update_sql_container( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + create_update_sql_container_parameters={ + "resource": { + "id": "str", + "analyticalStorageTtl": 0, + "clientEncryptionPolicy": { + "includedPaths": [ + { + "clientEncryptionKeyId": "str", + "encryptionAlgorithm": "str", + "encryptionType": "str", + "path": "str", + } + ], + "policyFormatVersion": 0, + }, + "computedProperties": [{"name": "str", "query": "str"}], + "conflictResolutionPolicy": { + "conflictResolutionPath": "str", + "conflictResolutionProcedure": "str", + "mode": "LastWriterWins", + }, + "createMode": "Default", + "defaultTtl": 0, + "indexingPolicy": { + "automatic": bool, + "compositeIndexes": [[{"order": "str", "path": "str"}]], + "excludedPaths": [{"path": "str"}], + "includedPaths": [ + {"indexes": [{"dataType": "String", "kind": "Hash", "precision": 0}], "path": "str"} + ], + "indexingMode": "consistent", + "spatialIndexes": [{"path": "str", "types": ["str"]}], + }, + "partitionKey": {"kind": "Hash", "paths": ["str"], "systemKey": bool, "version": 0}, + "restoreParameters": { + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + }, + "uniqueKeyPolicy": {"uniqueKeys": [{"paths": ["str"]}]}, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_sql_container(self, resource_group): + response = await ( + await self.client.sql_resources.begin_delete_sql_container( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_sql_container_throughput(self, resource_group): + response = await self.client.sql_resources.get_sql_container_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_update_sql_container_throughput(self, resource_group): + response = await ( + await self.client.sql_resources.begin_update_sql_container_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_sql_container_to_autoscale(self, resource_group): + response = await ( + await self.client.sql_resources.begin_migrate_sql_container_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_sql_container_to_manual_throughput(self, resource_group): + response = await ( + await self.client.sql_resources.begin_migrate_sql_container_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_client_encryption_keys(self, resource_group): + response = self.client.sql_resources.list_client_encryption_keys( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_client_encryption_key(self, resource_group): + response = await self.client.sql_resources.get_client_encryption_key( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + client_encryption_key_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_client_encryption_key(self, resource_group): + response = await ( + await self.client.sql_resources.begin_create_update_client_encryption_key( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + client_encryption_key_name="str", + create_update_client_encryption_key_parameters={ + "resource": { + "encryptionAlgorithm": "str", + "id": "str", + "keyWrapMetadata": {"algorithm": "str", "name": "str", "type": "str", "value": "str"}, + "wrappedDataEncryptionKey": bytes("bytes", encoding="utf-8"), + } + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_sql_stored_procedures(self, resource_group): + response = self.client.sql_resources.list_sql_stored_procedures( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_sql_stored_procedure(self, resource_group): + response = await self.client.sql_resources.get_sql_stored_procedure( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + stored_procedure_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_sql_stored_procedure(self, resource_group): + response = await ( + await self.client.sql_resources.begin_create_update_sql_stored_procedure( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + stored_procedure_name="str", + create_update_sql_stored_procedure_parameters={ + "resource": {"id": "str", "body": "str"}, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_sql_stored_procedure(self, resource_group): + response = await ( + await self.client.sql_resources.begin_delete_sql_stored_procedure( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + stored_procedure_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_sql_user_defined_functions(self, resource_group): + response = self.client.sql_resources.list_sql_user_defined_functions( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_sql_user_defined_function(self, resource_group): + response = await self.client.sql_resources.get_sql_user_defined_function( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + user_defined_function_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_sql_user_defined_function(self, resource_group): + response = await ( + await self.client.sql_resources.begin_create_update_sql_user_defined_function( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + user_defined_function_name="str", + create_update_sql_user_defined_function_parameters={ + "resource": {"id": "str", "body": "str"}, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_sql_user_defined_function(self, resource_group): + response = await ( + await self.client.sql_resources.begin_delete_sql_user_defined_function( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + user_defined_function_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_sql_triggers(self, resource_group): + response = self.client.sql_resources.list_sql_triggers( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_sql_trigger(self, resource_group): + response = await self.client.sql_resources.get_sql_trigger( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + trigger_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_sql_trigger(self, resource_group): + response = await ( + await self.client.sql_resources.begin_create_update_sql_trigger( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + trigger_name="str", + create_update_sql_trigger_parameters={ + "resource": {"id": "str", "body": "str", "triggerOperation": "str", "triggerType": "str"}, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_sql_trigger(self, resource_group): + response = await ( + await self.client.sql_resources.begin_delete_sql_trigger( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + trigger_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_sql_role_definition(self, resource_group): + response = await self.client.sql_resources.get_sql_role_definition( + role_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_sql_role_definition(self, resource_group): + response = await ( + await self.client.sql_resources.begin_create_update_sql_role_definition( + role_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + create_update_sql_role_definition_parameters={ + "assignableScopes": ["str"], + "permissions": [{"dataActions": ["str"], "notDataActions": ["str"]}], + "roleName": "str", + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_sql_role_definition(self, resource_group): + response = await ( + await self.client.sql_resources.begin_delete_sql_role_definition( + role_definition_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_sql_role_definitions(self, resource_group): + response = self.client.sql_resources.list_sql_role_definitions( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_sql_role_assignment(self, resource_group): + response = await self.client.sql_resources.get_sql_role_assignment( + role_assignment_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_sql_role_assignment(self, resource_group): + response = await ( + await self.client.sql_resources.begin_create_update_sql_role_assignment( + role_assignment_id="str", + resource_group_name=resource_group.name, + account_name="str", + create_update_sql_role_assignment_parameters={ + "principalId": "str", + "roleDefinitionId": "str", + "scope": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_sql_role_assignment(self, resource_group): + response = await ( + await self.client.sql_resources.begin_delete_sql_role_assignment( + role_assignment_id="str", + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_sql_role_assignments(self, resource_group): + response = self.client.sql_resources.list_sql_role_assignments( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_retrieve_continuous_backup_information(self, resource_group): + response = await ( + await self.client.sql_resources.begin_retrieve_continuous_backup_information( + resource_group_name=resource_group.name, + account_name="str", + database_name="str", + container_name="str", + location={"location": "str"}, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_table_resources_operations.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_table_resources_operations.py new file mode 100644 index 00000000000..a7a85f0807e --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_table_resources_operations.py @@ -0,0 +1,172 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementTableResourcesOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_tables(self, resource_group): + response = self.client.table_resources.list_tables( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r for r in response] + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_table(self, resource_group): + response = self.client.table_resources.get_table( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_create_update_table(self, resource_group): + response = self.client.table_resources.begin_create_update_table( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + create_update_table_parameters={ + "resource": { + "id": "str", + "createMode": "Default", + "restoreParameters": { + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + }, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_delete_table(self, resource_group): + response = self.client.table_resources.begin_delete_table( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_get_table_throughput(self, resource_group): + response = self.client.table_resources.get_table_throughput( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_update_table_throughput(self, resource_group): + response = self.client.table_resources.begin_update_table_throughput( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_table_to_autoscale(self, resource_group): + response = self.client.table_resources.begin_migrate_table_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_migrate_table_to_manual_throughput(self, resource_group): + response = self.client.table_resources.begin_migrate_table_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_begin_retrieve_continuous_backup_information(self, resource_group): + response = self.client.table_resources.begin_retrieve_continuous_backup_information( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + location={"location": "str"}, + api_version="2024-08-15", + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_table_resources_operations_async.py b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_table_resources_operations_async.py new file mode 100644 index 00000000000..abbc8c1a648 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/generated_tests/test_cosmos_db_management_table_resources_operations_async.py @@ -0,0 +1,185 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.skip("you may need to update the auto-generated test case before run it") +class TestCosmosDBManagementTableResourcesOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_tables(self, resource_group): + response = self.client.table_resources.list_tables( + resource_group_name=resource_group.name, + account_name="str", + api_version="2024-08-15", + ) + result = [r async for r in response] + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_table(self, resource_group): + response = await self.client.table_resources.get_table( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_create_update_table(self, resource_group): + response = await ( + await self.client.table_resources.begin_create_update_table( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + create_update_table_parameters={ + "resource": { + "id": "str", + "createMode": "Default", + "restoreParameters": { + "restoreSource": "str", + "restoreTimestampInUtc": "2020-02-20 00:00:00", + "restoreWithTtlDisabled": bool, + }, + }, + "id": "str", + "location": "str", + "name": "str", + "options": {"autoscaleSettings": {"maxThroughput": 0}, "throughput": 0}, + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_delete_table(self, resource_group): + response = await ( + await self.client.table_resources.begin_delete_table( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_get_table_throughput(self, resource_group): + response = await self.client.table_resources.get_table_throughput( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + api_version="2024-08-15", + ) + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_update_table_throughput(self, resource_group): + response = await ( + await self.client.table_resources.begin_update_table_throughput( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + update_throughput_parameters={ + "resource": { + "autoscaleSettings": { + "maxThroughput": 0, + "autoUpgradePolicy": {"throughputPolicy": {"incrementPercent": 0, "isEnabled": bool}}, + "targetMaxThroughput": 0, + }, + "instantMaximumThroughput": "str", + "minimumThroughput": "str", + "offerReplacePending": "str", + "softAllowedMaximumThroughput": "str", + "throughput": 0, + }, + "id": "str", + "location": "str", + "name": "str", + "tags": {"str": "str"}, + "type": "str", + }, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_table_to_autoscale(self, resource_group): + response = await ( + await self.client.table_resources.begin_migrate_table_to_autoscale( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... 
+ + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_migrate_table_to_manual_throughput(self, resource_group): + response = await ( + await self.client.table_resources.begin_migrate_table_to_manual_throughput( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_begin_retrieve_continuous_backup_information(self, resource_group): + response = await ( + await self.client.table_resources.begin_retrieve_continuous_backup_information( + resource_group_name=resource_group.name, + account_name="str", + table_name="str", + location={"location": "str"}, + api_version="2024-08-15", + ) + ).result() # call '.result()' to poll until service return final result + + # please add some check logic here by yourself + # ... diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/setup.py b/sdk/cosmos/azure-mgmt-cosmosdb/setup.py index 7871875933c..04300d8bd29 100644 --- a/sdk/cosmos/azure-mgmt-cosmosdb/setup.py +++ b/sdk/cosmos/azure-mgmt-cosmosdb/setup.py @@ -75,6 +75,7 @@ setup( }, install_requires=[ "isodate>=0.6.1", + "typing-extensions>=4.6.0", "azure-common>=1.1", "azure-mgmt-core>=1.3.2", ], diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/tests/conftest.py b/sdk/cosmos/azure-mgmt-cosmosdb/tests/conftest.py new file mode 100644 index 00000000000..c6d1ee70d05 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/tests/conftest.py @@ -0,0 +1,35 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import os +import pytest +from dotenv import load_dotenv +from devtools_testutils import ( + test_proxy, + add_general_regex_sanitizer, + add_body_key_sanitizer, + add_header_regex_sanitizer, +) + +load_dotenv() + + +# aovid record sensitive identity information in recordings +@pytest.fixture(scope="session", autouse=True) +def add_sanitizers(test_proxy): + cosmosdbmanagement_subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "00000000-0000-0000-0000-000000000000") + cosmosdbmanagement_tenant_id = os.environ.get("AZURE_TENANT_ID", "00000000-0000-0000-0000-000000000000") + cosmosdbmanagement_client_id = os.environ.get("AZURE_CLIENT_ID", "00000000-0000-0000-0000-000000000000") + cosmosdbmanagement_client_secret = os.environ.get("AZURE_CLIENT_SECRET", "00000000-0000-0000-0000-000000000000") + add_general_regex_sanitizer(regex=cosmosdbmanagement_subscription_id, value="00000000-0000-0000-0000-000000000000") + add_general_regex_sanitizer(regex=cosmosdbmanagement_tenant_id, value="00000000-0000-0000-0000-000000000000") + add_general_regex_sanitizer(regex=cosmosdbmanagement_client_id, value="00000000-0000-0000-0000-000000000000") + add_general_regex_sanitizer(regex=cosmosdbmanagement_client_secret, value="00000000-0000-0000-0000-000000000000") + + add_header_regex_sanitizer(key="Set-Cookie", value="[set-cookie;]") + add_header_regex_sanitizer(key="Cookie", value="cookie;") + add_body_key_sanitizer(json_path="$..access_token", value="access_token") diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_database_accounts_operations_async_test.py b/sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_database_accounts_operations_async_test.py new file mode 100644 index 00000000000..982ea83e375 --- /dev/null +++ 
b/sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_database_accounts_operations_async_test.py @@ -0,0 +1,29 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.live_test_only +class TestCosmosDBManagementDatabaseAccountsOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list_by_resource_group(self, resource_group): + response = self.client.database_accounts.list_by_resource_group( + resource_group_name=resource_group.name, + ) + result = [r async for r in response] + assert result == [] diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_database_accounts_operations_test.py b/sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_database_accounts_operations_test.py new file mode 100644 index 00000000000..abf31d4ef2c --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_database_accounts_operations_test.py @@ -0,0 +1,28 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.live_test_only +class TestCosmosDBManagementDatabaseAccountsOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list_by_resource_group(self, resource_group): + response = self.client.database_accounts.list_by_resource_group( + resource_group_name=resource_group.name, + ) + result = [r for r in response] + assert result == [] diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_operations_async_test.py b/sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_operations_async_test.py new file mode 100644 index 00000000000..6de5a28e4a8 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_operations_async_test.py @@ -0,0 +1,28 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb.aio import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer +from devtools_testutils.aio import recorded_by_proxy_async + +AZURE_LOCATION = "eastus" + + +@pytest.mark.live_test_only +class TestCosmosDBManagementOperationsAsync(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient, is_async=True) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy_async + async def test_list(self, resource_group): + response = self.client.operations.list() + result = [r async for r in response] + assert result + diff --git a/sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_operations_test.py b/sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_operations_test.py new file mode 100644 index 00000000000..03bbf3098d9 --- /dev/null +++ b/sdk/cosmos/azure-mgmt-cosmosdb/tests/test_cosmos_db_management_operations_test.py @@ -0,0 +1,27 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) AutoRest Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +import pytest +from azure.mgmt.cosmosdb import CosmosDBManagementClient + +from devtools_testutils import AzureMgmtRecordedTestCase, RandomNameResourceGroupPreparer, recorded_by_proxy + +AZURE_LOCATION = "eastus" + + +@pytest.mark.live_test_only +class TestCosmosDBManagementOperations(AzureMgmtRecordedTestCase): + def setup_method(self, method): + self.client = self.create_mgmt_client(CosmosDBManagementClient) + + @RandomNameResourceGroupPreparer(location=AZURE_LOCATION) + @recorded_by_proxy + def test_list(self, resource_group): + response = self.client.operations.list() + result = [r for r in response] + assert result + From e2bd5f1ddba6e1eacd89eba9f7c17f22e80e861b Mon Sep 17 00:00:00 2001 From: Azure SDK Bot <53356347+azure-sdk@users.noreply.github.com> Date: Wed, 18 Sep 2024 08:38:51 -0700 Subject: [PATCH 11/17] Update installed event processor version for net6 to net8 update (#37443) Co-authored-by: James Suplizio --- .github/workflows/event-processor.yml | 4 ++-- .github/workflows/scheduled-event-processor.yml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/event-processor.yml b/.github/workflows/event-processor.yml index 23c074b547b..907c0d5a6e9 100644 --- a/.github/workflows/event-processor.yml +++ b/.github/workflows/event-processor.yml @@ -58,7 +58,7 @@ jobs: run: > dotnet tool install Azure.Sdk.Tools.GitHubEventProcessor - --version 1.0.0-dev.20240909.2 + --version 1.0.0-dev.20240917.2 --add-source https://pkgs.dev.azure.com/azure-sdk/public/_packaging/azure-sdk-for-net/nuget/v3/index.json --global shell: bash @@ -114,7 +114,7 @@ jobs: run: > dotnet tool install Azure.Sdk.Tools.GitHubEventProcessor - --version 1.0.0-dev.20240909.2 + --version 1.0.0-dev.20240917.2 --add-source https://pkgs.dev.azure.com/azure-sdk/public/_packaging/azure-sdk-for-net/nuget/v3/index.json --global shell: bash diff --git 
a/.github/workflows/scheduled-event-processor.yml b/.github/workflows/scheduled-event-processor.yml index ce312eab41a..4b5f1132211 100644 --- a/.github/workflows/scheduled-event-processor.yml +++ b/.github/workflows/scheduled-event-processor.yml @@ -39,7 +39,7 @@ jobs: run: > dotnet tool install Azure.Sdk.Tools.GitHubEventProcessor - --version 1.0.0-dev.20240909.2 + --version 1.0.0-dev.20240917.2 --add-source https://pkgs.dev.azure.com/azure-sdk/public/_packaging/azure-sdk-for-net/nuget/v3/index.json --global shell: bash From 203a6cc1d0763c1cf46ee936333ac53e8821f788 Mon Sep 17 00:00:00 2001 From: Neehar Duvvuri <40341266+needuv@users.noreply.github.com> Date: Wed, 18 Sep 2024 13:39:55 -0400 Subject: [PATCH 12/17] Remove Unused websocket-client Dependency (#37445) --- sdk/evaluation/azure-ai-evaluation/setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/evaluation/azure-ai-evaluation/setup.py b/sdk/evaluation/azure-ai-evaluation/setup.py index 3d0cdb27873..20a2fca6987 100644 --- a/sdk/evaluation/azure-ai-evaluation/setup.py +++ b/sdk/evaluation/azure-ai-evaluation/setup.py @@ -67,7 +67,6 @@ setup( install_requires=[ "promptflow-devkit>=1.15.0", "promptflow-core>=1.15.0", - "websocket-client>=1.2.0", "numpy>=1.23.2; python_version<'3.12'", "numpy>=1.26.4; python_version>='3.12'", "pyjwt>=2.8.0", From 19649840a2e8ef1b1d1ae864e1c002e0da795c7b Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Wed, 18 Sep 2024 10:54:16 -0700 Subject: [PATCH 13/17] enable bandit (#37415) * enable bandit * remove pin for importlib-metadata --- eng/tox/tox.ini | 1 - sdk/ai/azure-ai-generative/pyproject.toml | 1 - sdk/ai/azure-ai-resources/pyproject.toml | 1 - sdk/evaluation/azure-ai-evaluation/pyproject.toml | 1 - 4 files changed, 4 deletions(-) diff --git a/eng/tox/tox.ini b/eng/tox/tox.ini index 468f11ea28e..0711eebdedf 100644 --- a/eng/tox/tox.ini +++ b/eng/tox/tox.ini @@ -470,7 +470,6 @@ setenv = PROXY_URL=http://localhost:5015 deps = {[base]deps} - importlib-metadata<5.0 
commands = python {repository_root}/eng/tox/create_package_and_install.py \ -d {envtmpdir} \ diff --git a/sdk/ai/azure-ai-generative/pyproject.toml b/sdk/ai/azure-ai-generative/pyproject.toml index 094fe3998b5..e1901c46398 100644 --- a/sdk/ai/azure-ai-generative/pyproject.toml +++ b/sdk/ai/azure-ai-generative/pyproject.toml @@ -5,7 +5,6 @@ verifytypes = false pyright = false pylint = true black = false -bandit = false sphinx=true breaking = false whl_no_aio = false diff --git a/sdk/ai/azure-ai-resources/pyproject.toml b/sdk/ai/azure-ai-resources/pyproject.toml index 6ed94c2681c..27ab78fef49 100644 --- a/sdk/ai/azure-ai-resources/pyproject.toml +++ b/sdk/ai/azure-ai-resources/pyproject.toml @@ -5,7 +5,6 @@ verifytypes = false pyright = false pylint = false black = false -bandit = false sphinx = true breaking = false diff --git a/sdk/evaluation/azure-ai-evaluation/pyproject.toml b/sdk/evaluation/azure-ai-evaluation/pyproject.toml index 27f0b158f84..d109f756a75 100644 --- a/sdk/evaluation/azure-ai-evaluation/pyproject.toml +++ b/sdk/evaluation/azure-ai-evaluation/pyproject.toml @@ -3,4 +3,3 @@ mypy = false pyright = false pylint = false black = true -bandit = false From bdceb13a2c3e809920c2b66e37f49ee9dba9269a Mon Sep 17 00:00:00 2001 From: Azure SDK Bot <53356347+azure-sdk@users.noreply.github.com> Date: Wed, 18 Sep 2024 14:03:27 -0700 Subject: [PATCH 14/17] Update perf tests to use federated auth (#37446) Co-authored-by: Wes Haggard --- eng/common/pipelines/templates/jobs/perf.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/eng/common/pipelines/templates/jobs/perf.yml b/eng/common/pipelines/templates/jobs/perf.yml index d1204f284ce..71df4bb1a3e 100644 --- a/eng/common/pipelines/templates/jobs/perf.yml +++ b/eng/common/pipelines/templates/jobs/perf.yml @@ -120,6 +120,10 @@ jobs: ServiceDirectory: ${{ parameters.ServiceDirectory }} Location: westus ResourceType: perf + ServiceConnection: azure-sdk-tests + SubscriptionConfigurationFilePaths: + - 
eng/common/TestResources/sub-config/AzurePublicMsft.json + UseFederatedAuth: true - script: >- dotnet run -- run @@ -179,3 +183,7 @@ jobs: parameters: ServiceDirectory: ${{ parameters.ServiceDirectory }} ResourceType: perf + ServiceConnection: azure-sdk-tests + SubscriptionConfigurationFilePaths: + - eng/common/TestResources/sub-config/AzurePublicMsft.json + UseFederatedAuth: true From eab7a3ab6714d48da3b76ce6daaa50e93a627b43 Mon Sep 17 00:00:00 2001 From: swathipil <76007337+swathipil@users.noreply.github.com> Date: Wed, 18 Sep 2024 16:11:49 -0500 Subject: [PATCH 15/17] [SchemaRegistry] prepare release JSON Custom GA (#37448) * remove b2 refs * update release date --- sdk/schemaregistry/azure-schemaregistry/CHANGELOG.md | 2 +- sdk/schemaregistry/azure-schemaregistry/dev_requirements.txt | 2 +- sdk/schemaregistry/azure-schemaregistry/samples/README.md | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sdk/schemaregistry/azure-schemaregistry/CHANGELOG.md b/sdk/schemaregistry/azure-schemaregistry/CHANGELOG.md index 2d1cb1faf6e..bdfb6fc4253 100644 --- a/sdk/schemaregistry/azure-schemaregistry/CHANGELOG.md +++ b/sdk/schemaregistry/azure-schemaregistry/CHANGELOG.md @@ -1,6 +1,6 @@ # Release History -## 1.3.0 (2024-09-17) +## 1.3.0 (2024-09-18) This version and all future versions will require Python 3.8+. Python 3.7 is no longer supported. 
diff --git a/sdk/schemaregistry/azure-schemaregistry/dev_requirements.txt b/sdk/schemaregistry/azure-schemaregistry/dev_requirements.txt index 3d285dac392..6bd2799120e 100644 --- a/sdk/schemaregistry/azure-schemaregistry/dev_requirements.txt +++ b/sdk/schemaregistry/azure-schemaregistry/dev_requirements.txt @@ -1,6 +1,6 @@ -e ../../../tools/azure-sdk-tools -e ../../core/azure-core --e ../../identity/azure-identity +azure-identity==1.17.0 jsonschema>=4.10.3 aiohttp>=3.0 genson diff --git a/sdk/schemaregistry/azure-schemaregistry/samples/README.md b/sdk/schemaregistry/azure-schemaregistry/samples/README.md index 122504eb864..e772362c9d7 100644 --- a/sdk/schemaregistry/azure-schemaregistry/samples/README.md +++ b/sdk/schemaregistry/azure-schemaregistry/samples/README.md @@ -40,7 +40,7 @@ If you do not have an existing Azure account, you may sign up for a free trial o 1. Install the Azure Schema Registry client library and Azure Identity client library for Python with [pip](https://pypi.org/project/pip/): ```bash -pip install azure-schemaregistry==1.3.0b2 +pip install azure-schemaregistry ``` To run samples utilizing the Azure Active Directory for authentication, please install the azure-identity library: @@ -52,7 +52,7 @@ pip install azure-identity To use the built-in `jsonschema` validation for the JSON Schema Encoder, install the Azure Schema Registry client library with `jsonencoder` extras installed: ```bash -pip install azure-schemaregistry[jsonencoder]==1.3.0b2 azure-identity +pip install azure-schemaregistry[jsonencoder] azure-identity ``` Additionally, if using with `azure.eventhub.EventData`, install `azure-eventhub>=5.9.0`: From fa1bc864001e80f5df0bbfe572fa1c039d3df5db Mon Sep 17 00:00:00 2001 From: Neehar Duvvuri <40341266+needuv@users.noreply.github.com> Date: Wed, 18 Sep 2024 17:26:13 -0400 Subject: [PATCH 16/17] Flatten Namespaces for Evaluation SDK (#37398) * remove public evaluators and evaluate modules * rename synthetic to simulator and expose only 
one namespace * clean up some references * fix some broken imports * add details on breaking change * fix changelog grammar issue * fix changelog grammar issue * attempt at fixing tests * change patch * disable verifytypes --- .../azure-ai-evaluation/CHANGELOG.md | 3 +- sdk/evaluation/azure-ai-evaluation/README.md | 6 +-- .../azure/ai/evaluation/__init__.py | 49 ++++++++++++++++++ .../{evaluate => _evaluate}/__init__.py | 4 -- .../_batch_run_client/__init__.py | 0 .../_batch_run_client/batch_run_context.py | 4 +- .../_batch_run_client/code_client.py | 2 +- .../_batch_run_client/proxy_client.py | 0 .../{evaluate => _evaluate}/_eval_run.py | 6 +-- .../{evaluate => _evaluate}/_evaluate.py | 5 +- .../_telemetry/__init__.py | 4 +- .../{evaluate => _evaluate}/_utils.py | 2 +- .../ai/evaluation/_evaluators/__init__.py | 3 ++ .../_bleu/__init__.py | 0 .../_bleu/_bleu.py | 0 .../_chat/__init__.py | 0 .../_chat/_chat.py | 0 .../_chat/retrieval/__init__.py | 0 .../_chat/retrieval/_retrieval.py | 0 .../_chat/retrieval/retrieval.prompty | 0 .../_coherence/__init__.py | 0 .../_coherence/_coherence.py | 0 .../_coherence/coherence.prompty | 0 .../_content_safety/__init__.py | 0 .../_content_safety/_content_safety.py | 0 .../_content_safety/_content_safety_base.py | 2 +- .../_content_safety/_content_safety_chat.py | 0 .../_content_safety/_hate_unfairness.py | 0 .../_content_safety/_self_harm.py | 0 .../_content_safety/_sexual.py | 0 .../_content_safety/_violence.py | 0 .../_eci/__init__.py | 0 .../{evaluators => _evaluators}/_eci/_eci.py | 0 .../_f1_score/__init__.py | 0 .../_f1_score/_f1_score.py | 0 .../_fluency/__init__.py | 0 .../_fluency/_fluency.py | 0 .../_fluency/fluency.prompty | 0 .../_gleu/__init__.py | 0 .../_gleu/_gleu.py | 0 .../_groundedness/__init__.py | 0 .../_groundedness/_groundedness.py | 0 .../_groundedness/groundedness.prompty | 0 .../_meteor/__init__.py | 0 .../_meteor/_meteor.py | 0 .../_protected_material/__init__.py | 0 .../_protected_material.py | 0 
.../_protected_materials/__init__.py | 0 .../_protected_materials.py | 0 .../_qa/__init__.py | 0 .../{evaluators => _evaluators}/_qa/_qa.py | 0 .../_relevance/__init__.py | 0 .../_relevance/_relevance.py | 0 .../_relevance/relevance.prompty | 0 .../_rouge/__init__.py | 0 .../_rouge/_rouge.py | 0 .../_similarity/__init__.py | 0 .../_similarity/_similarity.py | 0 .../_similarity/similarity.prompty | 0 .../_xpia/__init__.py | 0 .../{evaluators => _evaluators}/_xpia/xpia.py | 0 .../ai/evaluation/evaluators/__init__.py | 50 ------------------- .../azure/ai/evaluation/simulator/__init__.py | 13 +++++ .../_adversarial_scenario.py} | 0 .../_adversarial_simulator.py} | 13 ++--- .../constants.py => simulator/_constants.py} | 0 .../_conversation/__init__.py | 8 +-- .../_conversation/_conversation.py | 6 +-- .../_conversation/constants.py | 0 .../_direct_attack_simulator.py} | 10 ++-- .../_helpers/__init__.py | 0 .../_helpers/_language_suffix_mapping.py | 2 +- .../_indirect_attack_simulator.py} | 6 +-- .../_model_tools/__init__.py | 0 .../_model_tools/_identity_manager.py | 2 +- .../_model_tools/_proxy_completion_model.py | 0 .../_model_tools/_rai_client.py | 2 +- .../_model_tools/_template_handler.py | 4 +- .../_model_tools/models.py | 0 .../{synthetic => simulator}/_utils.py | 0 .../azure/ai/evaluation/synthetic/__init__.py | 13 ----- .../azure-ai-evaluation/pyproject.toml | 3 +- .../tests/e2etests/test_adv_simulator.py | 28 +++++------ .../tests/e2etests/test_builtin_evaluators.py | 4 +- .../tests/e2etests/test_evaluate.py | 3 +- .../tests/e2etests/test_metrics_upload.py | 14 +++--- .../tests/unittests/test_batch_run_context.py | 2 +- .../unittests/test_built_in_evaluator.py | 2 +- .../tests/unittests/test_chat_evaluator.py | 2 +- .../test_content_safety_chat_evaluator.py | 2 +- .../test_content_safety_defect_rate.py | 4 +- .../tests/unittests/test_eval_run.py | 8 +-- .../tests/unittests/test_evaluate.py | 12 ++--- .../unittests/test_evaluate_telemetry.py | 4 +- 
.../unittests/test_jailbreak_simulator.py | 22 ++++---- .../tests/unittests/test_save_eval.py | 4 +- .../tests/unittests/test_simulator.py | 22 ++++---- .../test_synthetic_callback_conv_bot.py | 2 +- .../test_synthetic_conversation_bot.py | 2 +- 99 files changed, 178 insertions(+), 181 deletions(-) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluate => _evaluate}/__init__.py (75%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluate => _evaluate}/_batch_run_client/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluate => _evaluate}/_batch_run_client/batch_run_context.py (94%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluate => _evaluate}/_batch_run_client/code_client.py (98%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluate => _evaluate}/_batch_run_client/proxy_client.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluate => _evaluate}/_eval_run.py (98%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluate => _evaluate}/_evaluate.py (99%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluate => _evaluate}/_telemetry/__init__.py (99%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluate => _evaluate}/_utils.py (99%) create mode 100644 sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/__init__.py rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_bleu/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_bleu/_bleu.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_chat/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_chat/_chat.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => 
_evaluators}/_chat/retrieval/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_chat/retrieval/_retrieval.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_chat/retrieval/retrieval.prompty (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_coherence/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_coherence/_coherence.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_coherence/coherence.prompty (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_content_safety/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_content_safety/_content_safety.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_content_safety/_content_safety_base.py (95%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_content_safety/_content_safety_chat.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_content_safety/_hate_unfairness.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_content_safety/_self_harm.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_content_safety/_sexual.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_content_safety/_violence.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_eci/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_eci/_eci.py (100%) rename 
sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_f1_score/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_f1_score/_f1_score.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_fluency/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_fluency/_fluency.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_fluency/fluency.prompty (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_gleu/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_gleu/_gleu.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_groundedness/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_groundedness/_groundedness.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_groundedness/groundedness.prompty (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_meteor/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_meteor/_meteor.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_protected_material/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_protected_material/_protected_material.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_protected_materials/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_protected_materials/_protected_materials.py (100%) rename 
sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_qa/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_qa/_qa.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_relevance/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_relevance/_relevance.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_relevance/relevance.prompty (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_rouge/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_rouge/_rouge.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_similarity/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_similarity/_similarity.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_similarity/similarity.prompty (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_xpia/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{evaluators => _evaluators}/_xpia/xpia.py (100%) delete mode 100644 sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/__init__.py create mode 100644 sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/__init__.py rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic/adversarial_scenario.py => simulator/_adversarial_scenario.py} (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic/adversarial_simulator.py => simulator/_adversarial_simulator.py} (97%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic/constants.py => 
simulator/_constants.py} (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic => simulator}/_conversation/__init__.py (97%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic => simulator}/_conversation/_conversation.py (96%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic => simulator}/_conversation/constants.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic/direct_attack_simulator.py => simulator/_direct_attack_simulator.py} (95%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic => simulator}/_helpers/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic => simulator}/_helpers/_language_suffix_mapping.py (92%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic/indirect_attack_simulator.py => simulator/_indirect_attack_simulator.py} (97%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic => simulator}/_model_tools/__init__.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic => simulator}/_model_tools/_identity_manager.py (98%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic => simulator}/_model_tools/_proxy_completion_model.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic => simulator}/_model_tools/_rai_client.py (98%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic => simulator}/_model_tools/_template_handler.py (97%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic => simulator}/_model_tools/models.py (100%) rename sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/{synthetic => simulator}/_utils.py (100%) delete mode 100644 sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/CHANGELOG.md b/sdk/evaluation/azure-ai-evaluation/CHANGELOG.md 
index 73bf1f76196..a3757b97025 100644 --- a/sdk/evaluation/azure-ai-evaluation/CHANGELOG.md +++ b/sdk/evaluation/azure-ai-evaluation/CHANGELOG.md @@ -2,9 +2,10 @@ ## 1.0.0b1 (Unreleased) - ### Breaking Changes +- The `synthetic` namespace has been renamed to `simulator`, and sub-namespaces under this module have been removed +- The `evaluate` and `evaluators` namespaces have been removed, and everything previously exposed in those modules has been added to the root namespace `azure.ai.evaluation` - The parameter name `project_scope` in content safety evaluators have been renamed to `azure_ai_project` for consistency with evaluate API and simulators. diff --git a/sdk/evaluation/azure-ai-evaluation/README.md b/sdk/evaluation/azure-ai-evaluation/README.md index eca50e36762..72e5e5e7dc6 100644 --- a/sdk/evaluation/azure-ai-evaluation/README.md +++ b/sdk/evaluation/azure-ai-evaluation/README.md @@ -25,9 +25,7 @@ from pprint import pprint from promptflow.core import AzureOpenAIModelConfiguration -from azure.ai.evaluation.evaluate import evaluate -from azure.ai.evaluation.evaluators import RelevanceEvaluator -from azure.ai.evaluation.evaluators.content_safety import ViolenceEvaluator +from azure.ai.evaluation import evaluate, RelevanceEvaluator, ViolenceEvaluator def answer_length(answer, **kwargs): @@ -97,7 +95,7 @@ Simulator expects the user to have a callback method that invokes their AI appli Here's a sample of a callback which invokes AsyncAzureOpenAI: ```python -from from azure.ai.evaluation.synthetic import AdversarialSimulator, AdversarialScenario +from from azure.ai.evaluation.simulator import AdversarialSimulator, AdversarialScenario from azure.identity import DefaultAzureCredential from typing import Any, Dict, List, Optional import asyncio diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/__init__.py index d540fd20468..945d2c84e7f 100644 --- 
a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/__init__.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/__init__.py @@ -1,3 +1,52 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- + +from ._evaluate._evaluate import evaluate +from ._evaluators._bleu import BleuScoreEvaluator +from ._evaluators._chat import ChatEvaluator +from ._evaluators._coherence import CoherenceEvaluator +from ._evaluators._content_safety import ( + ContentSafetyChatEvaluator, + ContentSafetyEvaluator, + HateUnfairnessEvaluator, + SelfHarmEvaluator, + SexualEvaluator, + ViolenceEvaluator, +) +from ._evaluators._f1_score import F1ScoreEvaluator +from ._evaluators._fluency import FluencyEvaluator +from ._evaluators._gleu import GleuScoreEvaluator +from ._evaluators._groundedness import GroundednessEvaluator +from ._evaluators._meteor import MeteorScoreEvaluator +from ._evaluators._protected_material import ProtectedMaterialEvaluator +from ._evaluators._qa import QAEvaluator +from ._evaluators._relevance import RelevanceEvaluator +from ._evaluators._rouge import RougeScoreEvaluator, RougeType +from ._evaluators._similarity import SimilarityEvaluator +from ._evaluators._xpia import IndirectAttackEvaluator + +__all__ = [ + "evaluate", + "CoherenceEvaluator", + "F1ScoreEvaluator", + "FluencyEvaluator", + "GroundednessEvaluator", + "RelevanceEvaluator", + "SimilarityEvaluator", + "QAEvaluator", + "ChatEvaluator", + "ViolenceEvaluator", + "SexualEvaluator", + "SelfHarmEvaluator", + "HateUnfairnessEvaluator", + "ContentSafetyEvaluator", + "ContentSafetyChatEvaluator", + "IndirectAttackEvaluator", + "BleuScoreEvaluator", + "GleuScoreEvaluator", + "MeteorScoreEvaluator", + "RougeScoreEvaluator", + "RougeType", + "ProtectedMaterialEvaluator", +] diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/__init__.py 
b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/__init__.py similarity index 75% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/__init__.py index f04582e12f0..d540fd20468 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/__init__.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/__init__.py @@ -1,7 +1,3 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- - -from ._evaluate import evaluate - -__all__ = ["evaluate"] diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_batch_run_client/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_batch_run_client/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_batch_run_client/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_batch_run_client/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_batch_run_client/batch_run_context.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_batch_run_client/batch_run_context.py similarity index 94% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_batch_run_client/batch_run_context.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_batch_run_client/batch_run_context.py index b8bbc7139c8..c8a82d9d53a 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_batch_run_client/batch_run_context.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_batch_run_client/batch_run_context.py @@ -24,8 +24,8 @@ class BatchRunContext: :param client: The client to run in the context. 
:type client: Union[ - ~azure.ai.evaluation.evaluate.code_client.CodeClient, - ~azure.ai.evaluation.evaluate.proxy_client.ProxyClient + ~azure.ai.evaluation._evaluate._batch_run_client.code_client.CodeClient, + ~azure.ai.evaluation._evaluate._batch_run_client.proxy_client.ProxyClient ] """ diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_batch_run_client/code_client.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_batch_run_client/code_client.py similarity index 98% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_batch_run_client/code_client.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_batch_run_client/code_client.py index 2393e3c17a3..ecb48eb903c 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_batch_run_client/code_client.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_batch_run_client/code_client.py @@ -8,7 +8,7 @@ import logging import pandas as pd from promptflow.contracts.types import AttrDict -from azure.ai.evaluation.evaluate._utils import _apply_column_mapping, _has_aggregator, get_int_env_var, load_jsonl +from azure.ai.evaluation._evaluate._utils import _apply_column_mapping, _has_aggregator, get_int_env_var, load_jsonl from promptflow.tracing import ThreadPoolExecutorWithContext as ThreadPoolExecutor from ..._constants import PF_BATCH_TIMEOUT_SEC, PF_BATCH_TIMEOUT_SEC_DEFAULT diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_batch_run_client/proxy_client.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_batch_run_client/proxy_client.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_batch_run_client/proxy_client.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_batch_run_client/proxy_client.py diff --git 
a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_eval_run.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_eval_run.py similarity index 98% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_eval_run.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_eval_run.py index 3b36bb57017..dec4d87a712 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_eval_run.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_eval_run.py @@ -59,7 +59,7 @@ class RunInfo: :param run_name: The name of a run. :type run_name: Optional[str] :return: The RunInfo instance. - :rtype: azure.ai.evaluation.evaluate.RunInfo + :rtype: azure.ai.evaluation._evaluate._eval_run.RunInfo """ return RunInfo(str(uuid.uuid4()), str(uuid.uuid4()), run_name or "") @@ -229,7 +229,7 @@ class EvalRun(contextlib.AbstractContextManager): # pylint: disable=too-many-in """The Context Manager enter call. :return: The instance of the class. - :rtype: azure.ai.evaluation.evaluate.EvalRun + :rtype: azure.ai.evaluation._evaluate._eval_run.EvalRun """ self._start_run() return self @@ -360,7 +360,7 @@ class EvalRun(contextlib.AbstractContextManager): # pylint: disable=too-many-in :param artifact_folder: The folder with artifacts to be uploaded. :type artifact_folder: str :param artifact_name: The name of the artifact to be uploaded. Defaults to - azure.ai.evaluation.evaluate.EvalRun.EVALUATION_ARTIFACT. + azure.ai.evaluation._evaluate._eval_run.EvalRun.EVALUATION_ARTIFACT. 
:type artifact_name: str """ if not self._check_state_and_log("log artifact", {RunStatus.BROKEN, RunStatus.NOT_STARTED}, False): diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_evaluate.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_evaluate.py similarity index 99% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_evaluate.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_evaluate.py index 653d352d608..e32666f1961 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_evaluate.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_evaluate.py @@ -60,7 +60,7 @@ def _aggregate_content_safety_metrics( module = inspect.getmodule(evaluators[evaluator_name]) if ( module - and module.__name__.startswith("azure.ai.evaluation.evaluators.") + and module.__name__.startswith("azure.ai.evaluation.") and metric_name.endswith("_score") and metric_name.replace("_score", "") in content_safety_metrics ): @@ -397,8 +397,7 @@ def evaluate( .. 
code-block:: python from promptflow.core import AzureOpenAIModelConfiguration - from azure.ai.evaluation.evaluate import evaluate - from azure.ai.evaluation.evaluators import RelevanceEvaluator, CoherenceEvaluator + from azure.ai.evaluation import evaluate, RelevanceEvaluator, CoherenceEvaluator model_config = AzureOpenAIModelConfiguration( diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_telemetry/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_telemetry/__init__.py similarity index 99% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_telemetry/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_telemetry/__init__.py index 09fd4927b3a..dd403c957d3 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_telemetry/__init__.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_telemetry/__init__.py @@ -35,9 +35,9 @@ def _get_evaluator_type(evaluator: Dict[str, Callable]): content_safety = False module = inspect.getmodule(evaluator) - built_in = module and module.__name__.startswith("azure.ai.evaluation.evaluators.") + built_in = module and module.__name__.startswith("azure.ai.evaluation._evaluators.") if built_in: - content_safety = module.__name__.startswith("azure.ai.evaluation.evaluators._content_safety") + content_safety = module.__name__.startswith("azure.ai.evaluation._evaluators._content_safety") if content_safety: return "content-safety" diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_utils.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_utils.py similarity index 99% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_utils.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_utils.py index 7dd3c30d924..db29faac6aa 100644 --- 
a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluate/_utils.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_utils.py @@ -12,7 +12,7 @@ from pathlib import Path import pandas as pd from azure.ai.evaluation._constants import DEFAULT_EVALUATION_RESULTS_FILE_NAME, Prefixes -from azure.ai.evaluation.evaluate._eval_run import EvalRun +from azure.ai.evaluation._evaluate._eval_run import EvalRun LOGGER = logging.getLogger(__name__) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/__init__.py new file mode 100644 index 00000000000..d540fd20468 --- /dev/null +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/__init__.py @@ -0,0 +1,3 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_bleu/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_bleu/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_bleu/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_bleu/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_bleu/_bleu.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_bleu/_bleu.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_bleu/_bleu.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_bleu/_bleu.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_chat/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_chat/__init__.py similarity index 100% rename from 
sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_chat/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_chat/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_chat/_chat.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_chat/_chat.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_chat/_chat.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_chat/_chat.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_chat/retrieval/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_chat/retrieval/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_chat/retrieval/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_chat/retrieval/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_chat/retrieval/_retrieval.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_chat/retrieval/_retrieval.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_chat/retrieval/_retrieval.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_chat/retrieval/_retrieval.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_chat/retrieval/retrieval.prompty b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_chat/retrieval/retrieval.prompty similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_chat/retrieval/retrieval.prompty rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_chat/retrieval/retrieval.prompty diff --git 
a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_coherence/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_coherence/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_coherence/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_coherence/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_coherence/_coherence.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_coherence/_coherence.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_coherence/_coherence.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_coherence/_coherence.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_coherence/coherence.prompty b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_coherence/coherence.prompty similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_coherence/coherence.prompty rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_coherence/coherence.prompty diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_content_safety.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py similarity index 100% rename from 
sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_content_safety.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_content_safety_base.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_content_safety_base.py similarity index 95% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_content_safety_base.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_content_safety_base.py index 728a6864739..38fc51e7dc3 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_content_safety_base.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_content_safety_base.py @@ -15,7 +15,7 @@ class ContentSafetyEvaluatorBase(ABC): :param metric: The metric to be evaluated. - :type metric: ~azure.ai.evaluation.evaluators._content_safety.flow.constants.EvaluationMetrics + :type metric: ~azure.ai.evaluation._evaluators._content_safety.flow.constants.EvaluationMetrics :param azure_ai_project: The scope of the Azure AI project. It contains subscription id, resource group, and project name. 
:type azure_ai_project: Dict diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_content_safety_chat.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_content_safety_chat.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_content_safety_chat.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_content_safety_chat.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_hate_unfairness.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_hate_unfairness.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_self_harm.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_self_harm.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_sexual.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_sexual.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_sexual.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_sexual.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_violence.py 
b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_violence.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_content_safety/_violence.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_violence.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_eci/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_eci/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_eci/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_eci/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_eci/_eci.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_eci/_eci.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_eci/_eci.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_eci/_eci.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_f1_score/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_f1_score/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_f1_score/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_f1_score/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_f1_score/_f1_score.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_f1_score/_f1_score.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py diff --git 
a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_fluency/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_fluency/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_fluency/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_fluency/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_fluency/_fluency.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_fluency/_fluency.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_fluency/_fluency.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_fluency/_fluency.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_fluency/fluency.prompty b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_fluency/fluency.prompty similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_fluency/fluency.prompty rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_fluency/fluency.prompty diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_gleu/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_gleu/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_gleu/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_gleu/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_gleu/_gleu.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_gleu/_gleu.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_gleu/_gleu.py rename to 
sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_gleu/_gleu.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_groundedness/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_groundedness/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_groundedness/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_groundedness/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_groundedness/_groundedness.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_groundedness/_groundedness.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_groundedness/groundedness.prompty b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_groundedness/groundedness.prompty similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_groundedness/groundedness.prompty rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_groundedness/groundedness.prompty diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_meteor/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_meteor/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_meteor/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_meteor/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_meteor/_meteor.py 
b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_meteor/_meteor.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_meteor/_meteor.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_meteor/_meteor.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_protected_material/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_protected_material/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_protected_material/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_protected_material/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_protected_material/_protected_material.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_protected_material/_protected_material.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_protected_materials/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_protected_materials/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_protected_materials/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_protected_materials/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_protected_materials/_protected_materials.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_protected_materials/_protected_materials.py similarity index 100% rename from 
sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_protected_materials/_protected_materials.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_protected_materials/_protected_materials.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_qa/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_qa/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_qa/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_qa/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_qa/_qa.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_qa/_qa.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_qa/_qa.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_qa/_qa.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_relevance/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_relevance/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_relevance/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_relevance/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_relevance/_relevance.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_relevance/_relevance.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_relevance/_relevance.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_relevance/_relevance.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_relevance/relevance.prompty 
b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_relevance/relevance.prompty similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_relevance/relevance.prompty rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_relevance/relevance.prompty diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_rouge/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_rouge/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_rouge/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_rouge/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_rouge/_rouge.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_rouge/_rouge.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_rouge/_rouge.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_rouge/_rouge.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_similarity/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_similarity/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_similarity/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_similarity/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_similarity/_similarity.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_similarity/_similarity.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_similarity/_similarity.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_similarity/_similarity.py diff --git 
a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_similarity/similarity.prompty b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_similarity/similarity.prompty similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_similarity/similarity.prompty rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_similarity/similarity.prompty diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_xpia/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_xpia/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_xpia/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_xpia/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_xpia/xpia.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_xpia/xpia.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/_xpia/xpia.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_xpia/xpia.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/__init__.py deleted file mode 100644 index a4492032d9b..00000000000 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/evaluators/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -# --------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. 
-# --------------------------------------------------------- - -from ._bleu import BleuScoreEvaluator -from ._chat import ChatEvaluator -from ._coherence import CoherenceEvaluator -from ._content_safety import ( - ContentSafetyChatEvaluator, - ContentSafetyEvaluator, - HateUnfairnessEvaluator, - SelfHarmEvaluator, - SexualEvaluator, - ViolenceEvaluator, -) -from ._f1_score import F1ScoreEvaluator -from ._fluency import FluencyEvaluator -from ._gleu import GleuScoreEvaluator -from ._groundedness import GroundednessEvaluator -from ._meteor import MeteorScoreEvaluator -from ._protected_material import ProtectedMaterialEvaluator -from ._qa import QAEvaluator -from ._relevance import RelevanceEvaluator -from ._rouge import RougeScoreEvaluator, RougeType -from ._similarity import SimilarityEvaluator -from ._xpia import IndirectAttackEvaluator - -__all__ = [ - "CoherenceEvaluator", - "F1ScoreEvaluator", - "FluencyEvaluator", - "GroundednessEvaluator", - "RelevanceEvaluator", - "SimilarityEvaluator", - "QAEvaluator", - "ChatEvaluator", - "ViolenceEvaluator", - "SexualEvaluator", - "SelfHarmEvaluator", - "HateUnfairnessEvaluator", - "ContentSafetyEvaluator", - "ContentSafetyChatEvaluator", - "IndirectAttackEvaluator", - "BleuScoreEvaluator", - "GleuScoreEvaluator", - "MeteorScoreEvaluator", - "RougeScoreEvaluator", - "RougeType", - "ProtectedMaterialEvaluator", -] diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/__init__.py new file mode 100644 index 00000000000..ac620a5cf8f --- /dev/null +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/__init__.py @@ -0,0 +1,13 @@ +from ._adversarial_scenario import AdversarialScenario +from ._adversarial_simulator import AdversarialSimulator +from ._constants import SupportedLanguages +from ._direct_attack_simulator import DirectAttackSimulator +from ._indirect_attack_simulator import IndirectAttackSimulator + 
+__all__ = [ + "AdversarialSimulator", + "AdversarialScenario", + "DirectAttackSimulator", + "IndirectAttackSimulator", + "SupportedLanguages", +] diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/adversarial_scenario.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_adversarial_scenario.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/adversarial_scenario.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_adversarial_scenario.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/adversarial_simulator.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_adversarial_simulator.py similarity index 97% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/adversarial_simulator.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_adversarial_simulator.py index 2903512a997..c627ee206da 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/adversarial_simulator.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_adversarial_simulator.py @@ -15,7 +15,8 @@ from tqdm import tqdm from promptflow._sdk._telemetry import ActivityType, monitor_operation from azure.ai.evaluation._http_utils import get_async_http_client -from azure.ai.evaluation.synthetic.adversarial_scenario import AdversarialScenario, _UnstableAdversarialScenario +from azure.ai.evaluation.simulator import AdversarialScenario +from azure.ai.evaluation.simulator._adversarial_scenario import _UnstableAdversarialScenario from ._conversation import CallbackConversationBot, ConversationBot, ConversationRole from ._conversation._conversation import simulate_conversation @@ -27,7 +28,7 @@ from ._model_tools import ( TokenScope, ) from ._utils import JsonLineList -from .constants import SupportedLanguages +from ._constants import SupportedLanguages logger 
= logging.getLogger(__name__) @@ -128,9 +129,9 @@ class AdversarialSimulator: :keyword scenario: Enum value specifying the adversarial scenario used for generating inputs. example: - - :py:const:`azure.ai.evaluation.synthetic.adversarial_scenario.AdversarialScenario.ADVERSARIAL_QA` - - :py:const:`azure.ai.evaluation.synthetic.adversarial_scenario.AdversarialScenario.ADVERSARIAL_CONVERSATION` - :paramtype scenario: azure.ai.evaluation.synthetic.adversarial_scenario.AdversarialScenario + - :py:const:`azure.ai.evaluation.simulator.AdversarialScenario.ADVERSARIAL_QA` + - :py:const:`azure.ai.evaluation.simulator.AdversarialScenario.ADVERSARIAL_CONVERSATION` + :paramtype scenario: azure.ai.evaluation.simulator.AdversarialScenario :keyword target: The target function to simulate adversarial inputs against. This function should be asynchronous and accept a dictionary representing the adversarial input. :paramtype target: Callable @@ -153,7 +154,7 @@ class AdversarialSimulator: Defaults to 3. :paramtype concurrent_async_task: int :keyword language: The language in which the conversation should be generated. Defaults to English. - :paramtype language: azure.ai.evaluation.synthetic.constants.SupportedLanguages + :paramtype language: azure.ai.evaluation.simulator.SupportedLanguages :keyword randomize_order: Whether or not the order of the prompts should be randomized. Defaults to True. :paramtype randomize_order: bool :keyword randomization_seed: The seed used to randomize prompt selection. 
If unset, the system's diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/constants.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_constants.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/constants.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_constants.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_conversation/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_conversation/__init__.py similarity index 97% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_conversation/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_conversation/__init__.py index a96f355e424..17e327dbd15 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_conversation/__init__.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_conversation/__init__.py @@ -25,7 +25,7 @@ class ConversationTurn: :param role: The role of the participant in the conversation. Accepted values are "user" and "assistant". - :type role: ~azure.ai.evaluation.synthetic._conversation.constants.ConversationRole + :type role: ~azure.ai.evaluation.simulator._conversation.constants.ConversationRole :param name: The name of the participant in the conversation. :type name: Optional[str] :param message: The message exchanged in the conversation. Defaults to an empty string. @@ -90,11 +90,11 @@ class ConversationBot: A conversation chat bot with a specific name, persona and a sentence that can be used as a conversation starter. :param role: The role of the bot in the conversation, either "user" or "assistant". 
- :type role: ~azure.ai.evaluation.synthetic._conversation.constants.ConversationRole + :type role: ~azure.ai.evaluation.simulator._conversation.constants.ConversationRole :param model: The LLM model to use for generating responses. :type model: Union[ - ~azure.ai.evaluation.synthetic._model_tools.LLMBase, - ~azure.ai.evaluation.synthetic._model_tools.OpenAIChatCompletionsModel + ~azure.ai.evaluation.simulator._model_tools.LLMBase, + ~azure.ai.evaluation.simulator._model_tools.OpenAIChatCompletionsModel ] :param conversation_template: A Jinja2 template describing the conversation to generate the prompt for the LLM :type conversation_template: str diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_conversation/_conversation.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_conversation/_conversation.py similarity index 96% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_conversation/_conversation.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_conversation/_conversation.py index c511f1b440f..f84dfa0272e 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_conversation/_conversation.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_conversation/_conversation.py @@ -6,8 +6,8 @@ import asyncio import logging from typing import Callable, Dict, List, Tuple, Union -from azure.ai.evaluation.synthetic._helpers._language_suffix_mapping import SUPPORTED_LANGUAGES_MAPPING -from azure.ai.evaluation.synthetic.constants import SupportedLanguages +from azure.ai.evaluation.simulator._helpers._language_suffix_mapping import SUPPORTED_LANGUAGES_MAPPING +from azure.ai.evaluation.simulator._constants import SupportedLanguages from ..._http_utils import AsyncHttpPipeline from . 
import ConversationBot, ConversationTurn @@ -110,7 +110,7 @@ async def simulate_conversation( if not isinstance(language, SupportedLanguages) or language not in SupportedLanguages: raise Exception( # pylint: disable=broad-exception-raised f"Language option '{language}' isn't supported. Select a supported language option from " - f"azure.ai.evaluation.synthetic._constants.SupportedLanguages: {[f'{e}' for e in SupportedLanguages]}" + f"azure.ai.evaluation.simulator.SupportedLanguages: {[f'{e}' for e in SupportedLanguages]}" ) first_prompt += f" {SUPPORTED_LANGUAGES_MAPPING[language]}" # Add all generated turns into array to pass for each bot while generating diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_conversation/constants.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_conversation/constants.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_conversation/constants.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_conversation/constants.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/direct_attack_simulator.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_direct_attack_simulator.py similarity index 95% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/direct_attack_simulator.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_direct_attack_simulator.py index 2666b9d057d..047a7e4b660 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/direct_attack_simulator.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_direct_attack_simulator.py @@ -10,10 +10,10 @@ from typing import Any, Callable, Dict, Optional from azure.identity import DefaultAzureCredential from promptflow._sdk._telemetry import ActivityType, monitor_operation -from 
azure.ai.evaluation.synthetic.adversarial_scenario import AdversarialScenario +from azure.ai.evaluation.simulator import AdversarialScenario from ._model_tools import AdversarialTemplateHandler, ManagedIdentityAPITokenManager, RAIClient, TokenScope -from .adversarial_simulator import AdversarialSimulator +from ._adversarial_simulator import AdversarialSimulator logger = logging.getLogger(__name__) @@ -113,9 +113,9 @@ class DirectAttackSimulator: :keyword scenario: Enum value specifying the adversarial scenario used for generating inputs. example: - - :py:const:`azure.ai.evaluation.synthetic.adversarial_scenario.AdversarialScenario.ADVERSARIAL_QA` - - :py:const:`azure.ai.evaluation.synthetic.adversarial_scenario.AdversarialScenario.ADVERSARIAL_CONVERSATION` - :paramtype scenario: azure.ai.evaluation.synthetic.adversarial_scenario.AdversarialScenario + - :py:const:`azure.ai.evaluation.simulator.AdversarialScenario.ADVERSARIAL_QA` + - :py:const:`azure.ai.evaluation.simulator.AdversarialScenario.ADVERSARIAL_CONVERSATION` + :paramtype scenario: azure.ai.evaluation.simulator.AdversarialScenario :keyword target: The target function to simulate adversarial inputs against. This function should be asynchronous and accept a dictionary representing the adversarial input. 
:paramtype target: Callable diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_helpers/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_helpers/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_helpers/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_helpers/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_helpers/_language_suffix_mapping.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_helpers/_language_suffix_mapping.py similarity index 92% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_helpers/_language_suffix_mapping.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_helpers/_language_suffix_mapping.py index 5e808898d5d..f5a5e250856 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_helpers/_language_suffix_mapping.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_helpers/_language_suffix_mapping.py @@ -1,7 +1,7 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from azure.ai.evaluation.synthetic.constants import SupportedLanguages +from azure.ai.evaluation.simulator._constants import SupportedLanguages BASE_SUFFIX = "Make the conversation in __language__ language." 
diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/indirect_attack_simulator.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_indirect_attack_simulator.py similarity index 97% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/indirect_attack_simulator.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_indirect_attack_simulator.py index 564e9829faa..f77b2138727 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/indirect_attack_simulator.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_indirect_attack_simulator.py @@ -9,10 +9,10 @@ from typing import Any, Callable, Dict from azure.identity import DefaultAzureCredential from promptflow._sdk._telemetry import ActivityType, monitor_operation -from azure.ai.evaluation.synthetic.adversarial_scenario import AdversarialScenario +from azure.ai.evaluation.simulator import AdversarialScenario from ._model_tools import AdversarialTemplateHandler, ManagedIdentityAPITokenManager, RAIClient, TokenScope -from .adversarial_simulator import AdversarialSimulator +from ._adversarial_simulator import AdversarialSimulator logger = logging.getLogger(__name__) @@ -109,7 +109,7 @@ class IndirectAttackSimulator: the scope of your AI system. :keyword scenario: Enum value specifying the adversarial scenario used for generating inputs. - :paramtype scenario: azure.ai.evaluation.synthetic.adversarial_scenario.AdversarialScenario + :paramtype scenario: azure.ai.evaluation.simulator.AdversarialScenario :keyword target: The target function to simulate adversarial inputs against. This function should be asynchronous and accept a dictionary representing the adversarial input. 
:paramtype target: Callable diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/__init__.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/__init__.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/__init__.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/_identity_manager.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_identity_manager.py similarity index 98% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/_identity_manager.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_identity_manager.py index c17c1d26253..162a5e31e5c 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/_identity_manager.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_identity_manager.py @@ -93,7 +93,7 @@ class ManagedIdentityAPITokenManager(APITokenManager): """API Token Manager for Azure Managed Identity :param token_scope: Token scope for Azure endpoint - :type token_scope: ~azure.ai.evaluation.synthetic._model_tools.TokenScope + :type token_scope: ~azure.ai.evaluation.simulator._model_tools.TokenScope :param logger: Logger object :type logger: logging.Logger :keyword kwargs: Additional keyword arguments diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/_proxy_completion_model.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/_proxy_completion_model.py rename to 
sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/_rai_client.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_rai_client.py similarity index 98% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/_rai_client.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_rai_client.py index 6cd062bb8d6..ad7a306d792 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/_rai_client.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_rai_client.py @@ -25,7 +25,7 @@ class RAIClient: :param azure_ai_project: The Azure AI project :type azure_ai_project: Dict :param token_manager: The token manager - :type token_manage: ~azure.ai.evaluation.synthetic._model_tools._identity_manager.APITokenManager + :type token_manage: ~azure.ai.evaluation.simulator._model_tools._identity_manager.APITokenManager """ def __init__(self, azure_ai_project: Dict, token_manager: APITokenManager) -> None: diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/_template_handler.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_template_handler.py similarity index 97% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/_template_handler.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_template_handler.py index 985bca3f778..06cf6a7bd34 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/_template_handler.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/_template_handler.py @@ -100,7 +100,7 @@ class AdversarialTemplateHandler: :param azure_ai_project: The Azure 
AI project. :type azure_ai_project: Dict[str, Any] :param rai_client: The RAI client. - :type rai_client: ~azure.ai.evaluation.synthetic._model_tools.RAIClient + :type rai_client: ~azure.ai.evaluation.simulator._model_tools.RAIClient """ def __init__(self, azure_ai_project: Dict[str, Any], rai_client: RAIClient) -> None: @@ -148,7 +148,7 @@ class AdversarialTemplateHandler: :param template_name: The name of the template. :type template_name: str :return: The generated content harm template. - :rtype: Optional[~azure.ai.evaluation.synthetic._model_tools.AdversarialTemplate] + :rtype: Optional[~azure.ai.evaluation.simulator._model_tools.AdversarialTemplate] """ if template_name in CONTENT_HARM_TEMPLATES_COLLECTION_KEY: return AdversarialTemplate(template_name=template_name, text=None, context_key=[], template_parameters=None) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/models.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/models.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_model_tools/models.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_model_tools/models.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_utils.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_utils.py similarity index 100% rename from sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/_utils.py rename to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/simulator/_utils.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/__init__.py deleted file mode 100644 index c542d3db062..00000000000 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/synthetic/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from .adversarial_scenario import AdversarialScenario 
-from .adversarial_simulator import AdversarialSimulator -from .constants import SupportedLanguages -from .direct_attack_simulator import DirectAttackSimulator -from .indirect_attack_simulator import IndirectAttackSimulator - -__all__ = [ - "AdversarialSimulator", - "AdversarialScenario", - "DirectAttackSimulator", - "IndirectAttackSimulator", - "SupportedLanguages", -] diff --git a/sdk/evaluation/azure-ai-evaluation/pyproject.toml b/sdk/evaluation/azure-ai-evaluation/pyproject.toml index d109f756a75..a2311bae532 100644 --- a/sdk/evaluation/azure-ai-evaluation/pyproject.toml +++ b/sdk/evaluation/azure-ai-evaluation/pyproject.toml @@ -2,4 +2,5 @@ mypy = false pyright = false pylint = false -black = true +bandit = false +verifytypes = false diff --git a/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_adv_simulator.py b/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_adv_simulator.py index e7c203ccbb7..c55843e24be 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_adv_simulator.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_adv_simulator.py @@ -11,7 +11,7 @@ from devtools_testutils import is_live class TestAdvSimulator: def test_adv_sim_init_with_prod_url(self, azure_cred, project_scope): os.environ.pop("RAI_SVC_URL", None) - from azure.ai.evaluation.synthetic import AdversarialSimulator + from azure.ai.evaluation.simulator import AdversarialSimulator azure_ai_project = { "subscription_id": project_scope["subscription_id"], @@ -24,7 +24,7 @@ class TestAdvSimulator: def test_incorrect_scenario_raises_error(self, azure_cred, project_scope): os.environ.pop("RAI_SVC_URL", None) - from azure.ai.evaluation.synthetic import AdversarialSimulator + from azure.ai.evaluation.simulator import AdversarialSimulator azure_ai_project = { "subscription_id": project_scope["subscription_id"], @@ -49,7 +49,7 @@ class TestAdvSimulator: def test_adv_qa_sim_responds_with_one_response(self, azure_cred, project_scope): 
os.environ.pop("RAI_SVC_URL", None) - from azure.ai.evaluation.synthetic import AdversarialScenario, AdversarialSimulator + from azure.ai.evaluation.simulator import AdversarialScenario, AdversarialSimulator azure_ai_project = { "subscription_id": project_scope["subscription_id"], @@ -100,7 +100,7 @@ class TestAdvSimulator: @pytest.mark.skip(reason="Temporary skip to merge 37201, will re-enable in subsequent pr") def test_adv_conversation_sim_responds_with_responses(self, azure_cred, project_scope): os.environ.pop("RAI_SVC_URL", None) - from azure.ai.evaluation.synthetic import AdversarialScenario, AdversarialSimulator + from azure.ai.evaluation.simulator import AdversarialScenario, AdversarialSimulator azure_ai_project = { "subscription_id": project_scope["subscription_id"], @@ -142,7 +142,7 @@ class TestAdvSimulator: def test_adv_summarization_sim_responds_with_responses(self, azure_cred, project_scope): os.environ.pop("RAI_SVC_URL", None) - from azure.ai.evaluation.synthetic import AdversarialScenario, AdversarialSimulator + from azure.ai.evaluation.simulator import AdversarialScenario, AdversarialSimulator azure_ai_project = { "subscription_id": project_scope["subscription_id"], @@ -183,7 +183,7 @@ class TestAdvSimulator: def test_adv_summarization_jailbreak_sim_responds_with_responses(self, azure_cred, project_scope): os.environ.pop("RAI_SVC_URL", None) - from azure.ai.evaluation.synthetic import AdversarialScenario, AdversarialSimulator + from azure.ai.evaluation.simulator import AdversarialScenario, AdversarialSimulator azure_ai_project = { "subscription_id": project_scope["subscription_id"], @@ -225,7 +225,7 @@ class TestAdvSimulator: def test_adv_rewrite_sim_responds_with_responses(self, azure_cred, project_scope): os.environ.pop("RAI_SVC_URL", None) - from azure.ai.evaluation.synthetic import AdversarialScenario, AdversarialSimulator + from azure.ai.evaluation.simulator import AdversarialScenario, AdversarialSimulator azure_ai_project = { 
"subscription_id": project_scope["subscription_id"], @@ -267,7 +267,7 @@ class TestAdvSimulator: @pytest.mark.skipif(is_live(), reason="API not fully released yet. Don't run in live mode unless connected to INT.") def test_adv_protected_matierial_sim_responds_with_responses(self, azure_cred, project_scope): os.environ.pop("RAI_SVC_URL", None) - from azure.ai.evaluation.synthetic import AdversarialScenario, AdversarialSimulator + from azure.ai.evaluation.simulator import AdversarialScenario, AdversarialSimulator azure_ai_project = { "subscription_id": project_scope["subscription_id"], @@ -308,8 +308,8 @@ class TestAdvSimulator: @pytest.mark.skipif(is_live(), reason="API not fully released yet. Don't run in live mode unless connected to INT.") def test_adv_eci_sim_responds_with_responses(self, azure_cred, project_scope): os.environ.pop("RAI_SVC_URL", None) - from azure.ai.evaluation.synthetic import AdversarialSimulator - from azure.ai.evaluation.synthetic.adversarial_scenario import _UnstableAdversarialScenario + from azure.ai.evaluation.simulator import AdversarialSimulator + from azure.ai.evaluation.simulator._adversarial_scenario import _UnstableAdversarialScenario azure_ai_project = { "subscription_id": project_scope["subscription_id"], @@ -353,7 +353,7 @@ class TestAdvSimulator: ) def test_adv_xpia_sim_responds_with_responses(self, azure_cred, project_scope): os.environ.pop("RAI_SVC_URL", None) - from azure.ai.evaluation.synthetic import AdversarialScenario, IndirectAttackSimulator + from azure.ai.evaluation.simulator import AdversarialScenario, IndirectAttackSimulator azure_ai_project = { "subscription_id": project_scope["subscription_id"], @@ -392,7 +392,7 @@ class TestAdvSimulator: ) def test_adv_sim_order_randomness_with_jailbreak(self, azure_cred, project_scope): os.environ.pop("RAI_SVC_URL", None) - from azure.ai.evaluation.synthetic import AdversarialScenario, AdversarialSimulator + from azure.ai.evaluation.simulator import AdversarialScenario, 
AdversarialSimulator azure_ai_project = { "subscription_id": project_scope["subscription_id"], @@ -469,7 +469,7 @@ class TestAdvSimulator: ) def test_adv_sim_order_randomness(self, azure_cred, project_scope): os.environ.pop("RAI_SVC_URL", None) - from azure.ai.evaluation.synthetic import AdversarialScenario, AdversarialSimulator + from azure.ai.evaluation.simulator import AdversarialScenario, AdversarialSimulator azure_ai_project = { "subscription_id": project_scope["subscription_id"], @@ -543,7 +543,7 @@ class TestAdvSimulator: ) def test_jailbreak_sim_order_randomness(self, azure_cred, project_scope): os.environ.pop("RAI_SVC_URL", None) - from azure.ai.evaluation.synthetic import AdversarialScenario, DirectAttackSimulator + from azure.ai.evaluation.simulator import AdversarialScenario, DirectAttackSimulator azure_ai_project = { "subscription_id": project_scope["subscription_id"], diff --git a/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_builtin_evaluators.py b/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_builtin_evaluators.py index 9374db0ca9b..5755cddb67c 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_builtin_evaluators.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_builtin_evaluators.py @@ -4,7 +4,7 @@ import numpy as np import pytest from devtools_testutils import is_live -from azure.ai.evaluation.evaluators import ( +from azure.ai.evaluation import ( BleuScoreEvaluator, ChatEvaluator, CoherenceEvaluator, @@ -27,7 +27,7 @@ from azure.ai.evaluation.evaluators import ( SimilarityEvaluator, ViolenceEvaluator, ) -from azure.ai.evaluation.evaluators._eci._eci import ECIEvaluator +from azure.ai.evaluation._evaluators._eci._eci import ECIEvaluator @pytest.mark.usefixtures("recording_injection", "recorded_test") diff --git a/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_evaluate.py b/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_evaluate.py index 8b071349496..d6c0a7e0aed 100644 --- 
a/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_evaluate.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_evaluate.py @@ -9,8 +9,7 @@ import pytest import requests from ci_tools.variables import in_ci -from azure.ai.evaluation.evaluate import evaluate -from azure.ai.evaluation.evaluators import ( +from azure.ai.evaluation import ( ContentSafetyEvaluator, F1ScoreEvaluator, FluencyEvaluator, diff --git a/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_metrics_upload.py b/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_metrics_upload.py index da4c8cdf775..74469c38566 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_metrics_upload.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_metrics_upload.py @@ -8,10 +8,10 @@ import pytest from devtools_testutils import is_live from promptflow.tracing import _start_trace -from azure.ai.evaluation.evaluate import _utils as ev_utils -from azure.ai.evaluation.evaluate._eval_run import EvalRun -from azure.ai.evaluation.evaluate._evaluate import evaluate -from azure.ai.evaluation.evaluators._f1_score._f1_score import F1ScoreEvaluator +from azure.ai.evaluation._evaluate import _utils as ev_utils +from azure.ai.evaluation._evaluate._eval_run import EvalRun +from azure.ai.evaluation._evaluate._evaluate import evaluate +from azure.ai.evaluation import F1ScoreEvaluator @pytest.fixture @@ -72,7 +72,7 @@ class TestMetricsUpload(object): workspace_name=project_scope["project_name"], ml_client=azure_ml_client, ) as ev_run: - with patch("azure.ai.evaluation.evaluate._eval_run.EvalRun.request_with_retry", return_value=mock_response): + with patch("azure.ai.evaluation._evaluate._eval_run.EvalRun.request_with_retry", return_value=mock_response): ev_run.write_properties_to_run_history({"test": 42}) assert any( lg_rec.levelno == logging.ERROR for lg_rec in caplog.records @@ -100,7 +100,7 @@ class TestMetricsUpload(object): ) as ev_run: mock_response = MagicMock() 
mock_response.status_code = 418 - with patch("azure.ai.evaluation.evaluate._eval_run.EvalRun.request_with_retry", return_value=mock_response): + with patch("azure.ai.evaluation._evaluate._eval_run.EvalRun.request_with_retry", return_value=mock_response): ev_run.log_metric("f1", 0.54) assert any( lg_rec.levelno == logging.WARNING for lg_rec in caplog.records @@ -133,7 +133,7 @@ class TestMetricsUpload(object): os.makedirs(os.path.join(tmp_path, "internal_dir"), exist_ok=True) with open(os.path.join(tmp_path, "internal_dir", "test.json"), "w") as fp: json.dump({"internal_f1": 0.6}, fp) - with patch("azure.ai.evaluation.evaluate._eval_run.EvalRun.request_with_retry", return_value=mock_response): + with patch("azure.ai.evaluation._evaluate._eval_run.EvalRun.request_with_retry", return_value=mock_response): ev_run.log_artifact(tmp_path) assert any( lg_rec.levelno == logging.WARNING for lg_rec in caplog.records diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_batch_run_context.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_batch_run_context.py index 22a739332a4..88113e3dd36 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_batch_run_context.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_batch_run_context.py @@ -6,7 +6,7 @@ from promptflow.client import PFClient from azure.ai.evaluation._constants import PF_BATCH_TIMEOUT_SEC, PF_BATCH_TIMEOUT_SEC_DEFAULT from azure.ai.evaluation._user_agent import USER_AGENT -from azure.ai.evaluation.evaluate._batch_run_client import BatchRunContext, CodeClient, ProxyClient +from azure.ai.evaluation._evaluate._batch_run_client import BatchRunContext, CodeClient, ProxyClient @pytest.fixture diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_built_in_evaluator.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_built_in_evaluator.py index 4f56589689f..1bcb6a49251 100644 --- 
a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_built_in_evaluator.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_built_in_evaluator.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock import pytest -from azure.ai.evaluation.evaluators import FluencyEvaluator +from azure.ai.evaluation import FluencyEvaluator async def fluency_async_mock(): diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_chat_evaluator.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_chat_evaluator.py index dcc2b0440fd..6e7eb55f3a5 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_chat_evaluator.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_chat_evaluator.py @@ -1,6 +1,6 @@ import pytest -from azure.ai.evaluation.evaluators import ChatEvaluator +from azure.ai.evaluation import ChatEvaluator @pytest.mark.usefixtures("mock_model_config") diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_content_safety_chat_evaluator.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_content_safety_chat_evaluator.py index 9b739d21ee6..cd44b280937 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_content_safety_chat_evaluator.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_content_safety_chat_evaluator.py @@ -1,6 +1,6 @@ import pytest -from azure.ai.evaluation.evaluators import ContentSafetyChatEvaluator +from azure.ai.evaluation import ContentSafetyChatEvaluator @pytest.mark.usefixtures("mock_project_scope") diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_content_safety_defect_rate.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_content_safety_defect_rate.py index e4417dfb567..7b4acd86efc 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_content_safety_defect_rate.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_content_safety_defect_rate.py @@ -4,8 +4,8 @@ import pathlib import pandas 
as pd import pytest -from azure.ai.evaluation.evaluate._evaluate import _aggregate_metrics -from azure.ai.evaluation.evaluators import ContentSafetyEvaluator +from azure.ai.evaluation._evaluate._evaluate import _aggregate_metrics +from azure.ai.evaluation import ContentSafetyEvaluator def _get_file(name): diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_eval_run.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_eval_run.py index a45ad5a8ff4..55979668600 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_eval_run.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_eval_run.py @@ -9,8 +9,8 @@ import jwt import pytest from promptflow.azure._utils._token_cache import ArmTokenCache -import azure.ai.evaluation.evaluate._utils as ev_utils -from azure.ai.evaluation.evaluate._eval_run import EvalRun, RunStatus +import azure.ai.evaluation._evaluate._utils as ev_utils +from azure.ai.evaluation._evaluate._eval_run import EvalRun, RunStatus def generate_mock_token(): @@ -260,7 +260,7 @@ class TestEvalRun: kwargs = {"artifact_folder": tmp_path} else: kwargs = {"key": "f1", "value": 0.5} - with patch("azure.ai.evaluation.evaluate._eval_run.BlobServiceClient", return_value=MagicMock()): + with patch("azure.ai.evaluation._evaluate._eval_run.BlobServiceClient", return_value=MagicMock()): fn(**kwargs) assert len(caplog.records) == 1 assert mock_response.text() in caplog.records[0].message @@ -338,7 +338,7 @@ class TestEvalRun: ) as run: assert len(caplog.records) == 1 assert "The results will be saved locally, but will not be logged to Azure." 
in caplog.records[0].message - with patch("azure.ai.evaluation.evaluate._eval_run.EvalRun.request_with_retry") as mock_request: + with patch("azure.ai.evaluation._evaluate._eval_run.EvalRun.request_with_retry") as mock_request: run.log_artifact("mock_dir") run.log_metric("foo", 42) run.write_properties_to_run_history({"foo": "bar"}) diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_evaluate.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_evaluate.py index 9562bf03d99..03f0c7dadbc 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_evaluate.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_evaluate.py @@ -10,20 +10,20 @@ from pandas.testing import assert_frame_equal from promptflow.client import PFClient from azure.ai.evaluation._constants import DEFAULT_EVALUATION_RESULTS_FILE_NAME -from azure.ai.evaluation.evaluate import evaluate -from azure.ai.evaluation.evaluate._evaluate import ( +from azure.ai.evaluation._evaluate._evaluate import ( _aggregate_metrics, _apply_target_to_data, _rename_columns_conditionally, ) -from azure.ai.evaluation.evaluate._utils import _apply_column_mapping, _trace_destination_from_project_scope -from azure.ai.evaluation.evaluators import ( +from azure.ai.evaluation._evaluate._utils import _apply_column_mapping, _trace_destination_from_project_scope +from azure.ai.evaluation import ( + evaluate, ContentSafetyEvaluator, F1ScoreEvaluator, GroundednessEvaluator, ProtectedMaterialEvaluator, ) -from azure.ai.evaluation.evaluators._eci._eci import ECIEvaluator +from azure.ai.evaluation._evaluators._eci._eci import ECIEvaluator def _get_file(name): @@ -393,7 +393,7 @@ class TestEvaluate: expected.at[3, "outputs.yeti.result"] = np.nan assert_frame_equal(expected, result_df) - @patch("azure.ai.evaluation.evaluate._evaluate._evaluate") + @patch("azure.ai.evaluation._evaluate._evaluate._evaluate") def test_evaluate_main_entry_guard(self, mock_evaluate, evaluate_test_data_jsonl_file): 
err_msg = ( "An attempt has been made to start a new process before the\n " diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_evaluate_telemetry.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_evaluate_telemetry.py index a909e7d5c39..6f6540b32c2 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_evaluate_telemetry.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_evaluate_telemetry.py @@ -9,8 +9,8 @@ import pandas as pd import pytest from promptflow.client import load_flow -from azure.ai.evaluation.evaluate._telemetry import log_evaluate_activity -from azure.ai.evaluation.evaluators import F1ScoreEvaluator, HateUnfairnessEvaluator +from azure.ai.evaluation._evaluate._telemetry import log_evaluate_activity +from azure.ai.evaluation import F1ScoreEvaluator, HateUnfairnessEvaluator def _add_nans(df, n, column_name): diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_jailbreak_simulator.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_jailbreak_simulator.py index ceb1621eb97..b7c6fcb263b 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_jailbreak_simulator.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_jailbreak_simulator.py @@ -9,7 +9,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest -from azure.ai.evaluation.synthetic import AdversarialScenario, DirectAttackSimulator +from azure.ai.evaluation.simulator import AdversarialScenario, DirectAttackSimulator @pytest.fixture() @@ -22,12 +22,12 @@ def async_callback(): @pytest.mark.unittest class TestSimulator: - @patch("azure.ai.evaluation.synthetic._model_tools._rai_client.RAIClient._get_service_discovery_url") + @patch("azure.ai.evaluation.simulator._model_tools._rai_client.RAIClient._get_service_discovery_url") @patch( - "azure.ai.evaluation.synthetic._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections" + 
"azure.ai.evaluation.simulator._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections" ) - @patch("azure.ai.evaluation.synthetic.adversarial_simulator.AdversarialSimulator._simulate_async") - @patch("azure.ai.evaluation.synthetic.adversarial_simulator.AdversarialSimulator._ensure_service_dependencies") + @patch("azure.ai.evaluation.simulator.AdversarialSimulator._simulate_async") + @patch("azure.ai.evaluation.simulator.AdversarialSimulator._ensure_service_dependencies") def test_initialization_with_all_valid_scenarios( self, mock_ensure_service_dependencies, @@ -58,9 +58,9 @@ class TestSimulator: assert callable(simulator) simulator(scenario=scenario, max_conversation_turns=1, max_simulation_results=3, target=async_callback) - @patch("azure.ai.evaluation.synthetic._model_tools._rai_client.RAIClient._get_service_discovery_url") + @patch("azure.ai.evaluation.simulator._model_tools._rai_client.RAIClient._get_service_discovery_url") @patch( - "azure.ai.evaluation.synthetic._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections" + "azure.ai.evaluation.simulator._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections" ) def test_simulator_raises_validation_error_with_unsupported_scenario( self, _get_content_harm_template_collections, _get_service_discovery_url @@ -84,12 +84,12 @@ class TestSimulator: ) ) - @patch("azure.ai.evaluation.synthetic._model_tools._rai_client.RAIClient._get_service_discovery_url") + @patch("azure.ai.evaluation.simulator._model_tools._rai_client.RAIClient._get_service_discovery_url") @patch( - "azure.ai.evaluation.synthetic._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections" + "azure.ai.evaluation.simulator._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections" ) - @patch("azure.ai.evaluation.synthetic.adversarial_simulator.AdversarialSimulator._simulate_async") - 
@patch("azure.ai.evaluation.synthetic.adversarial_simulator.AdversarialSimulator._ensure_service_dependencies") + @patch("azure.ai.evaluation.simulator.AdversarialSimulator._simulate_async") + @patch("azure.ai.evaluation.simulator.AdversarialSimulator._ensure_service_dependencies") def test_initialization_parity_with_evals( self, mock_ensure_service_dependencies, diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_save_eval.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_save_eval.py index 13fe5794017..95054b2dd19 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_save_eval.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_save_eval.py @@ -6,7 +6,7 @@ from typing import Any, List, Optional, Type import pytest -from azure.ai.evaluation import evaluators +import azure.ai.evaluation as evaluators @pytest.fixture @@ -39,7 +39,7 @@ class TestSaveEval: def test_load_and_run_evaluators(self, tmpdir, pf_client, data_file) -> None: """Test regular evaluator saving.""" - from azure.ai.evaluation.evaluators import F1ScoreEvaluator + from azure.ai.evaluation import F1ScoreEvaluator pf_client.flows.save(F1ScoreEvaluator, path=tmpdir) run = pf_client.run(tmpdir, data=data_file) diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_simulator.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_simulator.py index f2fb0f676a0..8d46c5e8bce 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_simulator.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_simulator.py @@ -9,7 +9,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest -from azure.ai.evaluation.synthetic import AdversarialScenario, AdversarialSimulator +from azure.ai.evaluation.simulator import AdversarialScenario, AdversarialSimulator @pytest.fixture() @@ -22,12 +22,12 @@ def async_callback(): @pytest.mark.unittest class TestSimulator: - 
@patch("azure.ai.evaluation.synthetic._model_tools._rai_client.RAIClient._get_service_discovery_url") + @patch("azure.ai.evaluation.simulator._model_tools._rai_client.RAIClient._get_service_discovery_url") @patch( - "azure.ai.evaluation.synthetic._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections" + "azure.ai.evaluation.simulator._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections" ) - @patch("azure.ai.evaluation.synthetic.adversarial_simulator.AdversarialSimulator._simulate_async") - @patch("azure.ai.evaluation.synthetic.adversarial_simulator.AdversarialSimulator._ensure_service_dependencies") + @patch("azure.ai.evaluation.simulator.AdversarialSimulator._simulate_async") + @patch("azure.ai.evaluation.simulator.AdversarialSimulator._ensure_service_dependencies") def test_initialization_with_all_valid_scenarios( self, mock_ensure_service_dependencies, @@ -59,9 +59,9 @@ class TestSimulator: assert callable(simulator) simulator(scenario=scenario, max_conversation_turns=1, max_simulation_results=3, target=async_callback) - @patch("azure.ai.evaluation.synthetic._model_tools._rai_client.RAIClient._get_service_discovery_url") + @patch("azure.ai.evaluation.simulator._model_tools._rai_client.RAIClient._get_service_discovery_url") @patch( - "azure.ai.evaluation.synthetic._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections" + "azure.ai.evaluation.simulator._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections" ) def test_simulator_raises_validation_error_with_unsupported_scenario( self, _get_content_harm_template_collections, _get_service_discovery_url @@ -86,12 +86,12 @@ class TestSimulator: ) ) - @patch("azure.ai.evaluation.synthetic._model_tools._rai_client.RAIClient._get_service_discovery_url") + @patch("azure.ai.evaluation.simulator._model_tools._rai_client.RAIClient._get_service_discovery_url") @patch( - 
"azure.ai.evaluation.synthetic._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections" + "azure.ai.evaluation.simulator._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections" ) - @patch("azure.ai.evaluation.synthetic.adversarial_simulator.AdversarialSimulator._simulate_async") - @patch("azure.ai.evaluation.synthetic.adversarial_simulator.AdversarialSimulator._ensure_service_dependencies") + @patch("azure.ai.evaluation.simulator.AdversarialSimulator._simulate_async") + @patch("azure.ai.evaluation.simulator.AdversarialSimulator._ensure_service_dependencies") def test_initialization_parity_with_evals( self, mock_ensure_service_dependencies, diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_synthetic_callback_conv_bot.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_synthetic_callback_conv_bot.py index 77b4edc7883..4f094603c66 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_synthetic_callback_conv_bot.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_synthetic_callback_conv_bot.py @@ -2,7 +2,7 @@ from unittest.mock import AsyncMock import pytest -from azure.ai.evaluation.synthetic._conversation import ( +from azure.ai.evaluation.simulator._conversation import ( CallbackConversationBot, ConversationRole, OpenAIChatCompletionsModel, diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_synthetic_conversation_bot.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_synthetic_conversation_bot.py index b05d9b71040..4d336018a81 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_synthetic_conversation_bot.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_synthetic_conversation_bot.py @@ -5,7 +5,7 @@ import pytest from azure.core.pipeline.policies import AsyncRetryPolicy, RetryMode from azure.ai.evaluation._http_utils import get_async_http_client -from azure.ai.evaluation.synthetic._conversation import ( 
+from azure.ai.evaluation.simulator._conversation import ( ConversationBot, ConversationRole, ConversationTurn, From 186410e47ab8f3426bf1aa38ed2988aad3affd36 Mon Sep 17 00:00:00 2001 From: Azure SDK Bot <53356347+azure-sdk@users.noreply.github.com> Date: Wed, 18 Sep 2024 15:00:10 -0700 Subject: [PATCH 17/17] Update to newer resource id for open source api (#37452) Co-authored-by: Wes Haggard --- eng/common/scripts/Helpers/Metadata-Helpers.ps1 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eng/common/scripts/Helpers/Metadata-Helpers.ps1 b/eng/common/scripts/Helpers/Metadata-Helpers.ps1 index 1e169198159..0bc9b02072e 100644 --- a/eng/common/scripts/Helpers/Metadata-Helpers.ps1 +++ b/eng/common/scripts/Helpers/Metadata-Helpers.ps1 @@ -10,7 +10,7 @@ function Generate-AadToken ($TenantId, $ClientId, $ClientSecret) "grant_type" = "client_credentials" "client_id" = $ClientId "client_secret" = $ClientSecret - "resource" = "api://2789159d-8d8b-4d13-b90b-ca29c1707afd" + "resource" = "api://2efaf292-00a0-426c-ba7d-f5d2b214b8fc" } Write-Host "Generating aad token..." $resp = Invoke-RestMethod $LoginAPIBaseURI -Method 'POST' -Headers $headers -Body $body