Add tenant selection during adding connection using AAD auth (#17266)

* add tenant selection when adding a connection

* fix tenant selection in a few other places

* fix question test

* lint

* prompt only when there are multiple tenants

* comments
This commit is contained in:
Hai Cao 2022-02-21 23:21:48 -08:00 коммит произвёл GitHub
Родитель 703d393555
Коммит 8de14b74e7
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
9 изменённых файлов: 85 добавлений и 12 удалений

Просмотреть файл

@ -128,6 +128,12 @@
<trans-unit id="aad">
<source xml:lang="en">AAD</source>
</trans-unit>
<trans-unit id="azureChooseTenant">
<source xml:lang="en">Choose an Azure tenant</source>
</trans-unit>
<trans-unit id="tenant">
<source xml:lang="en">Tenant</source>
</trans-unit>
<trans-unit id="usernamePrompt">
<source xml:lang="en">User name</source>
</trans-unit>

Просмотреть файл

@ -19,6 +19,9 @@ import { ConnectionProfile } from '../models/connectionProfile';
import { AccountStore } from './accountStore';
import providerSettings from '../azure/providerSettings';
import VscodeWrapper from '../controllers/vscodeWrapper';
import { QuestionTypes, IQuestion, IPrompter, INameValueChoice } from '../prompts/question';
import { Tenant } from '@microsoft/ads-adal-library';
import { AzureAccount } from '../../lib/ads-adal-library/src';
function getAppDataPath(): string {
let platform = process.platform;
@ -72,11 +75,13 @@ export class AzureController {
private storageService: StorageService;
private context: vscode.ExtensionContext;
private logger: AzureLogger;
private prompter: IPrompter;
private _vscodeWrapper: VscodeWrapper;
private credentialStoreInitialized = false;
constructor(context: vscode.ExtensionContext, logger?: AzureLogger) {
constructor(context: vscode.ExtensionContext, prompter: IPrompter, logger?: AzureLogger) {
this.context = context;
this.prompter = prompter;
if (!this.logger) {
this.logger = new AzureLogger();
}
@ -93,6 +98,25 @@ export class AzureController {
this.azureMessageDisplayer = new AzureMessageDisplayer();
}
private async promptForTenantChoice(account: AzureAccount, profile: ConnectionProfile): Promise<void> {
let tenantChoices: INameValueChoice[] = account.properties.tenants?.map(t => ({ name: t.displayName, value: t }));
if (tenantChoices && tenantChoices.length === 1) {
profile.tenantId = tenantChoices[0].value.id;
return;
}
let tenantQuestion: IQuestion = {
type: QuestionTypes.expand,
name: LocalizedConstants.tenant,
message: LocalizedConstants.azureChooseTenant,
choices: tenantChoices,
shouldPrompt: (answers) => profile.isAzureActiveDirectory() && tenantChoices.length > 1,
onAnswered: (value: Tenant) => {
profile.tenantId = value.id;
}
};
await this.prompter.promptSingle(tenantQuestion, true);
}
public async getTokens(profile: ConnectionProfile, accountStore: AccountStore, settings: AADResource): Promise<ConnectionProfile> {
let account: IAccount;
let config = vscode.workspace.getConfiguration('mssql').get('azureActiveDirectory');
@ -100,8 +124,12 @@ export class AzureController {
let azureCodeGrant = await this.createAuthCodeGrant();
account = await azureCodeGrant.startLogin();
await accountStore.addAccount(account);
if (!profile.tenantId) {
await this.promptForTenantChoice(account, profile);
}
let tid = profile.tenantId ? profile.tenantId : azureCodeGrant.getHomeTenant(account).id;
const token = await azureCodeGrant.getAccountSecurityToken(
account, azureCodeGrant.getHomeTenant(account).id, settings
account, tid, settings
);
if (!token) {
let errorMessage = LocalizedConstants.msgGetTokenFail;
@ -115,8 +143,12 @@ export class AzureController {
let azureDeviceCode = await this.createDeviceCode();
account = await azureDeviceCode.startLogin();
await accountStore.addAccount(account);
if (!profile.tenantId) {
await this.promptForTenantChoice(account, profile);
}
let tid = profile.tenantId ? profile.tenantId : azureDeviceCode.getHomeTenant(account).id;
const token = await azureDeviceCode.getAccountSecurityToken(
account, azureDeviceCode.getHomeTenant(account).id, settings
account, tid, settings
);
if (!token) {
let errorMessage = LocalizedConstants.msgGetTokenFail;
@ -136,7 +168,7 @@ export class AzureController {
await this._vscodeWrapper.showErrorMessage(LocalizedConstants.msgAccountNotFound);
throw new Error(LocalizedConstants.msgAccountNotFound);
}
let azureAccountToken = await this.refreshToken(account, accountStore, settings);
let azureAccountToken = await this.refreshToken(account, accountStore, settings, profile.tenantId);
if (!azureAccountToken) {
let errorMessage = LocalizedConstants.msgAccountRefreshFailed;
return this._vscodeWrapper.showErrorMessage(errorMessage, LocalizedConstants.refreshTokenLabel).then(async result => {
@ -155,7 +187,7 @@ export class AzureController {
return profile;
}
public async refreshToken(account: IAccount, accountStore: AccountStore, settings: AADResource): Promise<Token | undefined> {
public async refreshToken(account: IAccount, accountStore: AccountStore, settings: AADResource, tenantId: string = undefined): Promise<Token | undefined> {
try {
let token: Token;
if (account.properties.azureAuthType === 0) {
@ -166,7 +198,8 @@ export class AzureController {
return undefined;
}
await accountStore.addAccount(newAccount);
token = await azureCodeGrant.getAccountSecurityToken(account, azureCodeGrant.getHomeTenant(account).id, settings);
let tid = tenantId ? tenantId : azureCodeGrant.getHomeTenant(account).id;
token = await azureCodeGrant.getAccountSecurityToken(account, tid, settings);
} else if (account.properties.azureAuthType === 1) {
// Auth Device Code
let azureDeviceCode = await this.createDeviceCode();
@ -175,8 +208,9 @@ export class AzureController {
if (newAccount.isStale === true) {
return undefined;
}
let tid = tenantId ? tenantId : azureDeviceCode.getHomeTenant(account).id;
token = await azureDeviceCode.getAccountSecurityToken(
account, azureDeviceCode.getHomeTenant(account).id, providerSettings.resources.databaseResource);
account, tid, providerSettings.resources.databaseResource);
}
return token;
} catch (ex) {

Просмотреть файл

@ -120,7 +120,7 @@ export default class ConnectionManager {
}
if (!this.azureController) {
this.azureController = new AzureController(context);
this.azureController = new AzureController(context, prompter);
this.azureController.init();
}
@ -707,7 +707,7 @@ export default class ConnectionManager {
if (!connectionCreds.azureAccountToken || connectionCreds.expiresOn - currentTime < maxTolerance) {
let account = this.accountStore.getAccount(connectionCreds.accountId);
let profile = new ConnectionProfile(connectionCreds);
let azureAccountToken = await this.azureController.refreshToken(account, this.accountStore, providerSettings.resources.databaseResource);
let azureAccountToken = await this.azureController.refreshToken(account, this.accountStore, providerSettings.resources.databaseResource, profile.tenantId);
if (!azureAccountToken) {
let errorMessage = LocalizedConstants.msgAccountRefreshFailed;
let refreshResult = await this.vscodeWrapper.showErrorMessage(errorMessage, LocalizedConstants.refreshTokenLabel);

Просмотреть файл

@ -19,6 +19,7 @@ export class ConnectionCredentials implements IConnectionInfo {
public password: string;
public email: string | undefined;
public accountId: string | undefined;
public tenantId: string | undefined;
public port: number;
public authenticationType: string;
public azureAccountToken: string | undefined;

Просмотреть файл

@ -15,6 +15,7 @@ import { AzureController } from '../azure/azureController';
import { AccountStore } from '../azure/accountStore';
import { IAccount } from './contracts/azure/accountInterfaces';
import providerSettings from '../azure/providerSettings';
import { Tenant, AzureAccount } from '@microsoft/ads-adal-library';
// Concrete implementation of the IConnectionProfile interface
@ -30,11 +31,13 @@ export class ConnectionProfile extends ConnectionCredentials implements IConnect
public expiresOn: number | undefined;
public accountStore: AccountStore;
public accountId: string;
public tenantId: string;
constructor(connectionCredentials?: ConnectionCredentials) {
super();
if (connectionCredentials) {
this.accountId = connectionCredentials.accountId;
this.tenantId = connectionCredentials.tenantId;
this.authenticationType = connectionCredentials.authenticationType;
this.azureAccountToken = connectionCredentials.azureAccountToken;
this.expiresOn = connectionCredentials.expiresOn;
@ -68,6 +71,7 @@ export class ConnectionProfile extends ConnectionCredentials implements IConnect
let azureAccountChoices: INameValueChoice[] = ConnectionProfile.getAccountChoices(accountStore);
let accountAnswer: IAccount;
azureAccountChoices.unshift({ name: LocalizedConstants.azureAddAccount, value: 'addAccount' });
let tenantChoices: INameValueChoice[] = [];
let questions: IQuestion[] = await ConnectionCredentials.getRequiredCredentialValuesQuestions(profile, true,
@ -88,7 +92,26 @@ export class ConnectionProfile extends ConnectionCredentials implements IConnect
message: LocalizedConstants.azureChooseAccount,
choices: azureAccountChoices,
shouldPrompt: (answers) => profile.isAzureActiveDirectory(),
onAnswered: (value: IAccount) => accountAnswer = value
onAnswered: (value) => {
accountAnswer = value;
if (value !== 'addAccount') {
let account: AzureAccount = value;
tenantChoices.push(...account?.properties?.tenants.map(t => ({ name: t.displayName, value: t })));
if (tenantChoices.length === 1) {
profile.tenantId = tenantChoices[0].value.id;
}
}
}
},
{
type: QuestionTypes.expand,
name: LocalizedConstants.tenant,
message: LocalizedConstants.azureChooseTenant,
choices: tenantChoices,
shouldPrompt: (answers) => profile.isAzureActiveDirectory() && tenantChoices.length > 1,
onAnswered: (value: Tenant) => {
profile.tenantId = value.id;
}
},
{
type: QuestionTypes.input,

Просмотреть файл

@ -439,7 +439,7 @@ export class ObjectExplorerService {
let profile = new ConnectionProfile(connectionCredentials);
if (!connectionCredentials.azureAccountToken) {
let azureAccountToken = await azureController.refreshToken(
account, this._connectionManager.accountStore, providerSettings.resources.databaseResource);
account, this._connectionManager.accountStore, providerSettings.resources.databaseResource, connectionCredentials.tenantId);
if (!azureAccountToken) {
let errorMessage = LocalizedConstants.msgAccountRefreshFailed;
await this._connectionManager.vscodeWrapper.showErrorMessage(

Просмотреть файл

@ -28,6 +28,7 @@ function createTestCredentials(): IConnectionInfo {
password: '12345678',
email: 'test-email',
accountId: 'test-account-id',
tenantId: 'test-tenant-id',
port: 1234,
authenticationType: AuthenticationTypes[AuthenticationTypes.SqlLogin],
azureAccountToken: '',
@ -63,6 +64,7 @@ suite('Connection Profile tests', () => {
let mockAccountStore: AccountStore;
let mockAzureController: AzureController;
let mockContext: TypeMoq.IMock<vscode.ExtensionContext>;
let mockPrompter: TypeMoq.IMock<IPrompter>;
let globalstate: TypeMoq.IMock<vscode.Memento>;
setup(() => {
@ -70,7 +72,7 @@ suite('Connection Profile tests', () => {
globalstate = TypeMoq.Mock.ofType<vscode.Memento>();
mockContext = TypeMoq.Mock.ofType<vscode.ExtensionContext>();
mockContext.setup(c => c.workspaceState).returns(() => globalstate.object);
mockAzureController = new AzureController(mockContext.object);
mockAzureController = new AzureController(mockContext.object, mockPrompter.object);
mockAccountStore = new AccountStore(mockContext.object);
});
@ -104,6 +106,7 @@ suite('Connection Profile tests', () => {
LocalizedConstants.passwordPrompt, // Password
LocalizedConstants.msgSavePassword, // Save Password
LocalizedConstants.aad, // Choose AAD Account
LocalizedConstants.tenant, // Choose AAD Tenant
LocalizedConstants.profileNamePrompt // Profile Name
];

Просмотреть файл

@ -50,6 +50,7 @@ function createTestCredentials(): IConnectionInfo {
password: '12345678',
email: 'test-email',
accountId: 'test-account-id',
tenantId: 'test-tenant-id',
port: 1234,
authenticationType: AuthenticationTypes[AuthenticationTypes.SqlLogin],
azureAccountToken: '',

5
typings/vscode-mssql.d.ts поставляемый
Просмотреть файл

@ -120,6 +120,11 @@ declare module 'vscode-mssql' {
*/
accountId: string | undefined;
/**
* tenantId
*/
tenantId: string | undefined;
/**
* The port number to connect to.
*/