зеркало из https://github.com/Azure/ARO-RP.git
portal initial commit
This commit is contained in:
Родитель
c38da832db
Коммит
9e5c4f8930
|
@ -24,6 +24,7 @@ func usage() {
|
|||
fmt.Fprintf(flag.CommandLine.Output(), " %s deploy config.yaml location\n", os.Args[0])
|
||||
fmt.Fprintf(flag.CommandLine.Output(), " %s mirror [release_image...]\n", os.Args[0])
|
||||
fmt.Fprintf(flag.CommandLine.Output(), " %s monitor\n", os.Args[0])
|
||||
fmt.Fprintf(flag.CommandLine.Output(), " %s portal\n", os.Args[0])
|
||||
fmt.Fprintf(flag.CommandLine.Output(), " %s rp\n", os.Args[0])
|
||||
fmt.Fprintf(flag.CommandLine.Output(), " %s operator {master,worker}\n", os.Args[0])
|
||||
flag.PrintDefaults()
|
||||
|
@ -46,6 +47,9 @@ func main() {
|
|||
|
||||
var err error
|
||||
switch strings.ToLower(flag.Arg(0)) {
|
||||
case "deploy":
|
||||
checkArgs(3)
|
||||
err = deploy(ctx, log)
|
||||
case "mirror":
|
||||
checkMinArgs(1)
|
||||
err = mirror(ctx, log)
|
||||
|
@ -55,9 +59,9 @@ func main() {
|
|||
case "rp":
|
||||
checkArgs(1)
|
||||
err = rp(ctx, log)
|
||||
case "deploy":
|
||||
checkArgs(3)
|
||||
err = deploy(ctx, log)
|
||||
case "portal":
|
||||
checkArgs(1)
|
||||
err = portal(ctx, log)
|
||||
case "operator":
|
||||
checkArgs(2)
|
||||
err = operator(ctx, log)
|
||||
|
|
|
@ -0,0 +1,181 @@
|
|||
package main
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/database"
|
||||
"github.com/Azure/ARO-RP/pkg/deploy/generator"
|
||||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
"github.com/Azure/ARO-RP/pkg/metrics/noop"
|
||||
pkgportal "github.com/Azure/ARO-RP/pkg/portal"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/middleware"
|
||||
"github.com/Azure/ARO-RP/pkg/proxy"
|
||||
"github.com/Azure/ARO-RP/pkg/util/deployment"
|
||||
"github.com/Azure/ARO-RP/pkg/util/encryption"
|
||||
"github.com/Azure/ARO-RP/pkg/util/keyvault"
|
||||
)
|
||||
|
||||
// portal is the entrypoint of the "aro portal" subcommand.  It reads its
// configuration from environment variables and keyvault secrets, wires up
// the database and network listeners, then runs the admin portal server
// until it exits or fails.
func portal(ctx context.Context, log *logrus.Entry) error {
	_env, err := env.NewCore(ctx, log)
	if err != nil {
		return err
	}

	// PORTAL_HOSTNAME is only required outside development; in development
	// a localhost default is used further below.
	if _env.DeploymentMode() != deployment.Development {
		for _, key := range []string{
			"PORTAL_HOSTNAME",
		} {
			if _, found := os.LookupEnv(key); !found {
				return fmt.Errorf("environment variable %q unset", key)
			}
		}
	}

	// These variables are required in every deployment mode.
	for _, key := range []string{
		"AZURE_PORTAL_CLIENT_ID",
		"AZURE_PORTAL_ACCESS_GROUP_IDS",
		"AZURE_PORTAL_ELEVATED_GROUP_IDS",
	} {
		if _, found := os.LookupEnv(key); !found {
			return fmt.Errorf("environment variable %q unset", key)
		}
	}

	// Comma-separated AAD group IDs gating standard and elevated portal
	// access; both lists are validated as UUIDs up front.
	groupIDs, err := parseGroupIDs(os.Getenv("AZURE_PORTAL_ACCESS_GROUP_IDS"))
	if err != nil {
		return err
	}

	elevatedGroupIDs, err := parseGroupIDs(os.Getenv("AZURE_PORTAL_ELEVATED_GROUP_IDS"))
	if err != nil {
		return err
	}

	rpKVAuthorizer, err := _env.NewRPAuthorizer(_env.Environment().ResourceIdentifiers.KeyVault)
	if err != nil {
		return err
	}

	// TODO: should not be using the service keyvault here
	serviceKeyvaultURI, err := keyvault.URI(_env, generator.ServiceKeyvaultSuffix)
	if err != nil {
		return err
	}

	serviceKeyvault := keyvault.NewManager(rpKVAuthorizer, serviceKeyvaultURI)

	// Encryption key handed to the database client via the cipher below.
	key, err := serviceKeyvault.GetBase64Secret(ctx, env.EncryptionSecretName)
	if err != nil {
		return err
	}

	cipher, err := encryption.NewXChaCha20Poly1305(ctx, key)
	if err != nil {
		return err
	}

	// Metrics are deliberately no-op'd for the portal process.
	dbc, err := database.NewDatabaseClient(ctx, log.WithField("component", "database"), _env, &noop.Noop{}, cipher)
	if err != nil {
		return err
	}

	dbOpenShiftClusters, err := database.NewOpenShiftClusters(ctx, _env.DeploymentMode(), dbc)
	if err != nil {
		return err
	}

	dbPortal, err := database.NewPortal(ctx, _env.DeploymentMode(), dbc)
	if err != nil {
		return err
	}

	portalKeyvaultURI, err := keyvault.URI(_env, generator.PortalKeyvaultSuffix)
	if err != nil {
		return err
	}

	portalKeyvault := keyvault.NewManager(rpKVAuthorizer, portalKeyvaultURI)

	// TLS serving key/cert pair for the portal's HTTPS listener.
	servingKey, servingCerts, err := portalKeyvault.GetCertificateSecret(ctx, env.PortalServerSecretName)
	if err != nil {
		return err
	}

	// Client key/cert pair — presumably used for AAD client-certificate
	// authentication (see the portal-client setup in the docs); confirm
	// against pkgportal.NewPortal.
	clientKey, clientCerts, err := portalKeyvault.GetCertificateSecret(ctx, env.PortalServerClientSecretName)
	if err != nil {
		return err
	}

	sessionKey, err := portalKeyvault.GetBase64Secret(ctx, env.PortalServerSessionKeySecretName)
	if err != nil {
		return err
	}

	// The SSH host key is stored base64(DER)-encoded; decode the PKCS#1
	// blob into an *rsa.PrivateKey.
	b, err := portalKeyvault.GetBase64Secret(ctx, env.PortalServerSSHKeySecretName)
	if err != nil {
		return err
	}

	sshKey, err := x509.ParsePKCS1PrivateKey(b)
	if err != nil {
		return err
	}

	dialer, err := proxy.NewDialer(_env.DeploymentMode())
	if err != nil {
		return err
	}

	clientID := os.Getenv("AZURE_PORTAL_CLIENT_ID")
	verifier, err := middleware.NewVerifier(ctx, _env.TenantID(), clientID)
	if err != nil {
		return err
	}

	// Development binds to localhost only; production binds all interfaces
	// and takes its externally visible hostname from PORTAL_HOSTNAME
	// (presence validated above).
	hostname := "localhost:8444"
	address := "localhost:8444"
	sshAddress := "localhost:2222"
	if _env.DeploymentMode() != deployment.Development {
		hostname = os.Getenv("PORTAL_HOSTNAME")
		address = ":8444"
		sshAddress = ":2222"
	}

	l, err := net.Listen("tcp", address)
	if err != nil {
		return err
	}

	sshl, err := net.Listen("tcp", sshAddress)
	if err != nil {
		return err
	}

	log.Print("listening")

	p := pkgportal.NewPortal(_env, log.WithField("component", "portal"), log.WithField("component", "portal-access"), l, sshl, verifier, hostname, servingKey, servingCerts, clientID, clientKey, clientCerts, sessionKey, sshKey, groupIDs, elevatedGroupIDs, dbOpenShiftClusters, dbPortal, dialer)

	return p.Run(ctx)
}
|
||||
|
||||
func parseGroupIDs(_groupIDs string) ([]string, error) {
|
||||
groupIDs := strings.Split(_groupIDs, ",")
|
||||
for _, groupID := range groupIDs {
|
||||
_, err := uuid.FromString(groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return groupIDs, nil
|
||||
}
|
|
@ -129,6 +129,28 @@
|
|||
"[resourceId('Microsoft.DocumentDB/databaseAccounts/sqlDatabases', parameters('databaseAccountName'), parameters('databaseName'))]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"resource": {
|
||||
"id": "Portal",
|
||||
"partitionKey": {
|
||||
"paths": [
|
||||
"/id"
|
||||
],
|
||||
"kind": "Hash"
|
||||
},
|
||||
"defaultTtl": -1
|
||||
},
|
||||
"options": {}
|
||||
},
|
||||
"name": "[concat(parameters('databaseAccountName'), '/', parameters('databaseName'), '/Portal')]",
|
||||
"type": "Microsoft.DocumentDB/databaseAccounts/sqlDatabases/containers",
|
||||
"location": "[resourceGroup().location]",
|
||||
"apiVersion": "2019-08-01",
|
||||
"dependsOn": [
|
||||
"[resourceId('Microsoft.DocumentDB/databaseAccounts/sqlDatabases', parameters('databaseAccountName'), parameters('databaseName'))]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"resource": {
|
||||
|
|
|
@ -116,6 +116,47 @@
|
|||
"location": "[resourceGroup().location]",
|
||||
"apiVersion": "2016-10-01"
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"tenantId": "[subscription().tenantId]",
|
||||
"sku": {
|
||||
"family": "A",
|
||||
"name": "standard"
|
||||
},
|
||||
"accessPolicies": [
|
||||
{
|
||||
"tenantId": "[subscription().tenantId]",
|
||||
"objectId": "[parameters('rpServicePrincipalId')]",
|
||||
"permissions": {
|
||||
"secrets": [
|
||||
"get"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"tenantId": "[subscription().tenantId]",
|
||||
"objectId": "[parameters('adminObjectId')]",
|
||||
"permissions": {
|
||||
"secrets": [
|
||||
"set",
|
||||
"list"
|
||||
],
|
||||
"certificates": [
|
||||
"delete",
|
||||
"get",
|
||||
"import",
|
||||
"list"
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"enableSoftDelete": true
|
||||
},
|
||||
"name": "[concat(parameters('keyvaultPrefix'), '-por')]",
|
||||
"type": "Microsoft.KeyVault/vaults",
|
||||
"location": "[resourceGroup().location]",
|
||||
"apiVersion": "2016-10-01"
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"tenantId": "[subscription().tenantId]",
|
||||
|
|
|
@ -38,6 +38,15 @@
|
|||
"mdsdEnvironment": {
|
||||
"value": ""
|
||||
},
|
||||
"portalAccessGroupIds": {
|
||||
"value": ""
|
||||
},
|
||||
"portalClientId": {
|
||||
"value": ""
|
||||
},
|
||||
"portalElevatedGroupIds": {
|
||||
"value": ""
|
||||
},
|
||||
"rpImage": {
|
||||
"value": ""
|
||||
},
|
||||
|
|
|
@ -8,6 +8,9 @@
|
|||
"extraClusterKeyvaultAccessPolicies": {
|
||||
"value": []
|
||||
},
|
||||
"extraPortalKeyvaultAccessPolicies": {
|
||||
"value": []
|
||||
},
|
||||
"extraServiceKeyvaultAccessPolicies": {
|
||||
"value": []
|
||||
},
|
||||
|
|
|
@ -34,6 +34,17 @@
|
|||
}
|
||||
}
|
||||
],
|
||||
"portalKeyvaultAccessPolicies": [
|
||||
{
|
||||
"tenantId": "[subscription().tenantId]",
|
||||
"objectId": "[parameters('rpServicePrincipalId')]",
|
||||
"permissions": {
|
||||
"secrets": [
|
||||
"get"
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"serviceKeyvaultAccessPolicies": [
|
||||
{
|
||||
"tenantId": "[subscription().tenantId]",
|
||||
|
@ -55,6 +66,10 @@
|
|||
"type": "array",
|
||||
"defaultValue": []
|
||||
},
|
||||
"extraPortalKeyvaultAccessPolicies": {
|
||||
"type": "array",
|
||||
"defaultValue": []
|
||||
},
|
||||
"extraServiceKeyvaultAccessPolicies": {
|
||||
"type": "array",
|
||||
"defaultValue": []
|
||||
|
@ -140,6 +155,22 @@
|
|||
"condition": "[parameters('fullDeploy')]",
|
||||
"apiVersion": "2016-10-01"
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"tenantId": "[subscription().tenantId]",
|
||||
"sku": {
|
||||
"family": "A",
|
||||
"name": "standard"
|
||||
},
|
||||
"accessPolicies": "[concat(variables('portalKeyvaultAccessPolicies'), parameters('extraPortalKeyvaultAccessPolicies'))]",
|
||||
"enableSoftDelete": true
|
||||
},
|
||||
"name": "[concat(parameters('keyvaultPrefix'), '-por')]",
|
||||
"type": "Microsoft.KeyVault/vaults",
|
||||
"location": "[resourceGroup().location]",
|
||||
"condition": "[parameters('fullDeploy')]",
|
||||
"apiVersion": "2016-10-01"
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"tenantId": "[subscription().tenantId]",
|
||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -182,6 +182,31 @@ locations.
|
|||
>/dev/null
|
||||
```
|
||||
|
||||
1. Create an AAD application which will fake up the portal client.
|
||||
|
||||
This application requires client certificate authentication to be enabled. A
|
||||
suitable key/certificate file can be generated using the following helper
|
||||
utility:
|
||||
|
||||
```bash
|
||||
go run ./hack/genkey -client portal-client
|
||||
mv portal-client.* secrets
|
||||
```
|
||||
|
||||
```bash
|
||||
AZURE_PORTAL_CLIENT_ID="$(az ad app create \
|
||||
--display-name aro-v4-portal-shared \
|
||||
--identifier-uris "https://$(uuidgen)/" \
|
||||
--reply-urls "https://localhost:8444/callback" \
|
||||
--query appId \
|
||||
-o tsv)"
|
||||
az ad app credential reset \
|
||||
--id "$AZURE_PORTAL_CLIENT_ID" \
|
||||
--cert "$(base64 -w0 <secrets/portal-client.crt)" >/dev/null
|
||||
```
|
||||
|
||||
TODO: more steps are needed to configure aro-v4-portal-shared.
|
||||
|
||||
|
||||
## Certificates
|
||||
|
||||
|
@ -258,6 +283,9 @@ locations.
|
|||
export AZURE_ARM_CLIENT_ID='$AZURE_ARM_CLIENT_ID'
|
||||
export AZURE_ARM_CLIENT_SECRET='$AZURE_ARM_CLIENT_SECRET'
|
||||
export AZURE_FP_CLIENT_ID='$AZURE_FP_CLIENT_ID'
|
||||
export AZURE_PORTAL_CLIENT_ID='$AZURE_PORTAL_CLIENT_ID'
|
||||
export AZURE_PORTAL_ACCESS_GROUP_IDS='$ADMIN_OBJECT_ID'
|
||||
export AZURE_PORTAL_ELEVATED_GROUP_IDS='$ADMIN_OBJECT_ID'
|
||||
export AZURE_CLIENT_ID='$AZURE_CLIENT_ID'
|
||||
export AZURE_CLIENT_SECRET='$AZURE_CLIENT_SECRET'
|
||||
export AZURE_RP_CLIENT_ID='$AZURE_RP_CLIENT_ID'
|
||||
|
|
|
@ -84,6 +84,14 @@ import_certs_secrets() {
|
|||
--vault-name "$KEYVAULT_PREFIX-svc" \
|
||||
--name rp-server \
|
||||
--file secrets/localhost.pem
|
||||
az keyvault certificate import \
|
||||
--vault-name "$KEYVAULT_PREFIX-por" \
|
||||
--name portal-server \
|
||||
--file secrets/localhost.pem
|
||||
az keyvault certificate import \
|
||||
--vault-name "$KEYVAULT_PREFIX-por" \
|
||||
--name portal-client \
|
||||
--file secrets/portal-client.pem
|
||||
az keyvault certificate import \
|
||||
--vault-name "$KEYVAULT_PREFIX-svc" \
|
||||
--name cluster-mdsd \
|
||||
|
@ -104,6 +112,22 @@ import_certs_secrets() {
|
|||
--vault-name "$KEYVAULT_PREFIX-svc" \
|
||||
--name fe-encryption-key \
|
||||
--value "$(openssl rand -base64 32)"
|
||||
az keyvault secret list \
|
||||
--vault-name "$KEYVAULT_PREFIX-por" \
|
||||
--query '[].name' \
|
||||
-o tsv | grep -q ^portal-session-key$ || \
|
||||
az keyvault secret set \
|
||||
--vault-name "$KEYVAULT_PREFIX-por" \
|
||||
--name portal-session-key \
|
||||
--value "$(openssl rand -base64 32)"
|
||||
az keyvault secret list \
|
||||
--vault-name "$KEYVAULT_PREFIX-por" \
|
||||
--query '[].name' \
|
||||
-o tsv | grep -q ^portal-sshkey$ || \
|
||||
az keyvault secret set \
|
||||
--vault-name "$KEYVAULT_PREFIX-por" \
|
||||
--name portal-sshkey \
|
||||
--value "$(openssl genpkey -algorithm rsa -pkeyopt rsa_keygen_bits:2048 -outform der | base64 -w0)"
|
||||
}
|
||||
|
||||
update_parent_domain_dns_zone() {
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
package api
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
// Portal represents a portal action request: exactly one of SSH or
// Kubeconfig is expected to be set — TODO confirm against the handlers
// that populate it.
type Portal struct {
	MissingFields

	// Username is the requesting user; ID is the document identifier.
	Username string `json:"username"`
	ID       string `json:"id,omitempty"`

	SSH        *SSH        `json:"ssh,omitempty"`
	Kubeconfig *Kubeconfig `json:"kubeconfig,omitempty"`
}

// SSH holds the SSH-specific part of a portal request.  Master is
// presumably the index of the target master node — verify against the
// SSH handler.
type SSH struct {
	MissingFields

	Master int `json:"master,omitempty"`
}

// Kubeconfig holds the kubeconfig-specific part of a portal request.
// Elevated requests an elevated (admin) kubeconfig.
type Kubeconfig struct {
	MissingFields

	Elevated bool `json:"elevated,omitempty"`
}
|
|
@ -0,0 +1,38 @@
|
|||
package api
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
// PortalDocuments represents portal documents.
// pkg/database/cosmosdb requires its definition.
type PortalDocuments struct {
	// Count and ResourceID mirror the CosmosDB feed-response envelope
	// fields (_count, _rid).
	Count           int               `json:"_count,omitempty"`
	ResourceID      string            `json:"_rid,omitempty"`
	PortalDocuments []*PortalDocument `json:"Documents,omitempty"`
}

// String returns a JSON representation of the document list via the
// package's encodeJSON helper.
func (c *PortalDocuments) String() string {
	return encodeJSON(c)
}

// PortalDocument represents a portal document.
// pkg/database/cosmosdb requires its definition.
type PortalDocument struct {
	MissingFields

	// ID is the caller-chosen document identifier; the underscore-prefixed
	// fields are CosmosDB system properties set by the service.
	ID          string                 `json:"id,omitempty"`
	ResourceID  string                 `json:"_rid,omitempty"`
	Timestamp   int                    `json:"_ts,omitempty"`
	Self        string                 `json:"_self,omitempty"`
	ETag        string                 `json:"_etag,omitempty"`
	Attachments string                 `json:"_attachments,omitempty"`
	TTL         int                    `json:"ttl,omitempty"`
	LSN         int                    `json:"_lsn,omitempty"`
	Metadata    map[string]interface{} `json:"_metadata,omitempty"`

	// Portal is the application payload carried by the document.
	Portal *Portal `json:"portal,omitempty"`
}

// String returns a JSON representation of the document via the package's
// encodeJSON helper.
func (c *PortalDocument) String() string {
	return encodeJSON(c)
}
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
// Regular expressions used to validate the format of resource names and IDs acceptable by API.
|
||||
var (
|
||||
RxClusterID = regexp.MustCompile(`(?i)^/subscriptions/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/resourceGroups/[-a-z0-9_().]{0,89}[-a-z0-9_()]/providers/Microsoft\.RedHatOpenShift/openShiftClusters/[-a-z0-9_().]{0,89}[-a-z0-9_()]$`)
|
||||
RxResourceGroupID = regexp.MustCompile(`(?i)^/subscriptions/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/resourceGroups/[-a-z0-9_().]{0,89}[-a-z0-9_()]$`)
|
||||
RxSubnetID = regexp.MustCompile(`(?i)^/subscriptions/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/resourceGroups/[-a-z0-9_().]{0,89}[-a-z0-9_()]/providers/Microsoft\.Network/virtualNetworks/[-a-z0-9_.]{2,64}/subnets/[-a-z0-9_.]{2,80}$`)
|
||||
RxDomainName = regexp.MustCompile(`^` +
|
||||
|
|
|
@ -257,5 +257,5 @@ func (m *manager) saveGraph(ctx context.Context, g graph) error {
|
|||
return err
|
||||
}
|
||||
|
||||
return graph.CreateBlockBlobFromReader(bytes.NewReader([]byte(output)), nil)
|
||||
return graph.CreateBlockBlobFromReader(bytes.NewReader(output), nil)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,313 @@
|
|||
// Code generated by github.com/jim-minter/go-cosmosdb, DO NOT EDIT.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
pkg "github.com/Azure/ARO-RP/pkg/api"
|
||||
)
|
||||
|
||||
type portalDocumentClient struct {
|
||||
*databaseClient
|
||||
path string
|
||||
}
|
||||
|
||||
// PortalDocumentClient is a portalDocument client
|
||||
type PortalDocumentClient interface {
|
||||
Create(context.Context, string, *pkg.PortalDocument, *Options) (*pkg.PortalDocument, error)
|
||||
List(*Options) PortalDocumentIterator
|
||||
ListAll(context.Context, *Options) (*pkg.PortalDocuments, error)
|
||||
Get(context.Context, string, string, *Options) (*pkg.PortalDocument, error)
|
||||
Replace(context.Context, string, *pkg.PortalDocument, *Options) (*pkg.PortalDocument, error)
|
||||
Delete(context.Context, string, *pkg.PortalDocument, *Options) error
|
||||
Query(string, *Query, *Options) PortalDocumentRawIterator
|
||||
QueryAll(context.Context, string, *Query, *Options) (*pkg.PortalDocuments, error)
|
||||
ChangeFeed(*Options) PortalDocumentIterator
|
||||
}
|
||||
|
||||
type portalDocumentChangeFeedIterator struct {
|
||||
*portalDocumentClient
|
||||
continuation string
|
||||
options *Options
|
||||
}
|
||||
|
||||
type portalDocumentListIterator struct {
|
||||
*portalDocumentClient
|
||||
continuation string
|
||||
done bool
|
||||
options *Options
|
||||
}
|
||||
|
||||
type portalDocumentQueryIterator struct {
|
||||
*portalDocumentClient
|
||||
partitionkey string
|
||||
query *Query
|
||||
continuation string
|
||||
done bool
|
||||
options *Options
|
||||
}
|
||||
|
||||
// PortalDocumentIterator is a portalDocument iterator
|
||||
type PortalDocumentIterator interface {
|
||||
Next(context.Context, int) (*pkg.PortalDocuments, error)
|
||||
Continuation() string
|
||||
}
|
||||
|
||||
// PortalDocumentRawIterator is a portalDocument raw iterator
|
||||
type PortalDocumentRawIterator interface {
|
||||
PortalDocumentIterator
|
||||
NextRaw(context.Context, int, interface{}) error
|
||||
}
|
||||
|
||||
// NewPortalDocumentClient returns a new portalDocument client
|
||||
func NewPortalDocumentClient(collc CollectionClient, collid string) PortalDocumentClient {
|
||||
return &portalDocumentClient{
|
||||
databaseClient: collc.(*collectionClient).databaseClient,
|
||||
path: collc.(*collectionClient).path + "/colls/" + collid,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *portalDocumentClient) all(ctx context.Context, i PortalDocumentIterator) (*pkg.PortalDocuments, error) {
|
||||
allportalDocuments := &pkg.PortalDocuments{}
|
||||
|
||||
for {
|
||||
portalDocuments, err := i.Next(ctx, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if portalDocuments == nil {
|
||||
break
|
||||
}
|
||||
|
||||
allportalDocuments.Count += portalDocuments.Count
|
||||
allportalDocuments.ResourceID = portalDocuments.ResourceID
|
||||
allportalDocuments.PortalDocuments = append(allportalDocuments.PortalDocuments, portalDocuments.PortalDocuments...)
|
||||
}
|
||||
|
||||
return allportalDocuments, nil
|
||||
}
|
||||
|
||||
func (c *portalDocumentClient) Create(ctx context.Context, partitionkey string, newportalDocument *pkg.PortalDocument, options *Options) (portalDocument *pkg.PortalDocument, err error) {
|
||||
headers := http.Header{}
|
||||
headers.Set("X-Ms-Documentdb-Partitionkey", `["`+partitionkey+`"]`)
|
||||
|
||||
if options == nil {
|
||||
options = &Options{}
|
||||
}
|
||||
options.NoETag = true
|
||||
|
||||
err = c.setOptions(options, newportalDocument, headers)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = c.do(ctx, http.MethodPost, c.path+"/docs", "docs", c.path, http.StatusCreated, &newportalDocument, &portalDocument, headers)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *portalDocumentClient) List(options *Options) PortalDocumentIterator {
|
||||
continuation := ""
|
||||
if options != nil {
|
||||
continuation = options.Continuation
|
||||
}
|
||||
|
||||
return &portalDocumentListIterator{portalDocumentClient: c, options: options, continuation: continuation}
|
||||
}
|
||||
|
||||
func (c *portalDocumentClient) ListAll(ctx context.Context, options *Options) (*pkg.PortalDocuments, error) {
|
||||
return c.all(ctx, c.List(options))
|
||||
}
|
||||
|
||||
func (c *portalDocumentClient) Get(ctx context.Context, partitionkey, portalDocumentid string, options *Options) (portalDocument *pkg.PortalDocument, err error) {
|
||||
headers := http.Header{}
|
||||
headers.Set("X-Ms-Documentdb-Partitionkey", `["`+partitionkey+`"]`)
|
||||
|
||||
err = c.setOptions(options, nil, headers)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = c.do(ctx, http.MethodGet, c.path+"/docs/"+portalDocumentid, "docs", c.path+"/docs/"+portalDocumentid, http.StatusOK, nil, &portalDocument, headers)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *portalDocumentClient) Replace(ctx context.Context, partitionkey string, newportalDocument *pkg.PortalDocument, options *Options) (portalDocument *pkg.PortalDocument, err error) {
|
||||
headers := http.Header{}
|
||||
headers.Set("X-Ms-Documentdb-Partitionkey", `["`+partitionkey+`"]`)
|
||||
|
||||
err = c.setOptions(options, newportalDocument, headers)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = c.do(ctx, http.MethodPut, c.path+"/docs/"+newportalDocument.ID, "docs", c.path+"/docs/"+newportalDocument.ID, http.StatusOK, &newportalDocument, &portalDocument, headers)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *portalDocumentClient) Delete(ctx context.Context, partitionkey string, portalDocument *pkg.PortalDocument, options *Options) (err error) {
|
||||
headers := http.Header{}
|
||||
headers.Set("X-Ms-Documentdb-Partitionkey", `["`+partitionkey+`"]`)
|
||||
|
||||
err = c.setOptions(options, portalDocument, headers)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = c.do(ctx, http.MethodDelete, c.path+"/docs/"+portalDocument.ID, "docs", c.path+"/docs/"+portalDocument.ID, http.StatusNoContent, nil, nil, headers)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *portalDocumentClient) Query(partitionkey string, query *Query, options *Options) PortalDocumentRawIterator {
|
||||
continuation := ""
|
||||
if options != nil {
|
||||
continuation = options.Continuation
|
||||
}
|
||||
|
||||
return &portalDocumentQueryIterator{portalDocumentClient: c, partitionkey: partitionkey, query: query, options: options, continuation: continuation}
|
||||
}
|
||||
|
||||
func (c *portalDocumentClient) QueryAll(ctx context.Context, partitionkey string, query *Query, options *Options) (*pkg.PortalDocuments, error) {
|
||||
return c.all(ctx, c.Query(partitionkey, query, options))
|
||||
}
|
||||
|
||||
func (c *portalDocumentClient) ChangeFeed(options *Options) PortalDocumentIterator {
|
||||
continuation := ""
|
||||
if options != nil {
|
||||
continuation = options.Continuation
|
||||
}
|
||||
|
||||
return &portalDocumentChangeFeedIterator{portalDocumentClient: c, options: options, continuation: continuation}
|
||||
}
|
||||
|
||||
func (c *portalDocumentClient) setOptions(options *Options, portalDocument *pkg.PortalDocument, headers http.Header) error {
|
||||
if options == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if portalDocument != nil && !options.NoETag {
|
||||
if portalDocument.ETag == "" {
|
||||
return ErrETagRequired
|
||||
}
|
||||
headers.Set("If-Match", portalDocument.ETag)
|
||||
}
|
||||
if len(options.PreTriggers) > 0 {
|
||||
headers.Set("X-Ms-Documentdb-Pre-Trigger-Include", strings.Join(options.PreTriggers, ","))
|
||||
}
|
||||
if len(options.PostTriggers) > 0 {
|
||||
headers.Set("X-Ms-Documentdb-Post-Trigger-Include", strings.Join(options.PostTriggers, ","))
|
||||
}
|
||||
if len(options.PartitionKeyRangeID) > 0 {
|
||||
headers.Set("X-Ms-Documentdb-PartitionKeyRangeID", options.PartitionKeyRangeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *portalDocumentChangeFeedIterator) Next(ctx context.Context, maxItemCount int) (portalDocuments *pkg.PortalDocuments, err error) {
|
||||
headers := http.Header{}
|
||||
headers.Set("A-IM", "Incremental feed")
|
||||
|
||||
headers.Set("X-Ms-Max-Item-Count", strconv.Itoa(maxItemCount))
|
||||
if i.continuation != "" {
|
||||
headers.Set("If-None-Match", i.continuation)
|
||||
}
|
||||
|
||||
err = i.setOptions(i.options, nil, headers)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = i.do(ctx, http.MethodGet, i.path+"/docs", "docs", i.path, http.StatusOK, nil, &portalDocuments, headers)
|
||||
if IsErrorStatusCode(err, http.StatusNotModified) {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
i.continuation = headers.Get("Etag")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (i *portalDocumentChangeFeedIterator) Continuation() string {
|
||||
return i.continuation
|
||||
}
|
||||
|
||||
func (i *portalDocumentListIterator) Next(ctx context.Context, maxItemCount int) (portalDocuments *pkg.PortalDocuments, err error) {
|
||||
if i.done {
|
||||
return
|
||||
}
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("X-Ms-Max-Item-Count", strconv.Itoa(maxItemCount))
|
||||
if i.continuation != "" {
|
||||
headers.Set("X-Ms-Continuation", i.continuation)
|
||||
}
|
||||
|
||||
err = i.setOptions(i.options, nil, headers)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = i.do(ctx, http.MethodGet, i.path+"/docs", "docs", i.path, http.StatusOK, nil, &portalDocuments, headers)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
i.continuation = headers.Get("X-Ms-Continuation")
|
||||
i.done = i.continuation == ""
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (i *portalDocumentListIterator) Continuation() string {
|
||||
return i.continuation
|
||||
}
|
||||
|
||||
func (i *portalDocumentQueryIterator) Next(ctx context.Context, maxItemCount int) (portalDocuments *pkg.PortalDocuments, err error) {
|
||||
err = i.NextRaw(ctx, maxItemCount, &portalDocuments)
|
||||
return
|
||||
}
|
||||
|
||||
func (i *portalDocumentQueryIterator) NextRaw(ctx context.Context, maxItemCount int, raw interface{}) (err error) {
|
||||
if i.done {
|
||||
return
|
||||
}
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("X-Ms-Max-Item-Count", strconv.Itoa(maxItemCount))
|
||||
headers.Set("X-Ms-Documentdb-Isquery", "True")
|
||||
headers.Set("Content-Type", "application/query+json")
|
||||
if i.partitionkey != "" {
|
||||
headers.Set("X-Ms-Documentdb-Partitionkey", `["`+i.partitionkey+`"]`)
|
||||
} else {
|
||||
headers.Set("X-Ms-Documentdb-Query-Enablecrosspartition", "True")
|
||||
}
|
||||
if i.continuation != "" {
|
||||
headers.Set("X-Ms-Continuation", i.continuation)
|
||||
}
|
||||
|
||||
err = i.setOptions(i.options, nil, headers)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = i.do(ctx, http.MethodPost, i.path+"/docs", "docs", i.path, http.StatusOK, &i.query, &raw, headers)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
i.continuation = headers.Get("X-Ms-Continuation")
|
||||
i.done = i.continuation == ""
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (i *portalDocumentQueryIterator) Continuation() string {
|
||||
return i.continuation
|
||||
}
|
|
@ -0,0 +1,360 @@
|
|||
// Code generated by github.com/jim-minter/go-cosmosdb, DO NOT EDIT.
|
||||
|
||||
package cosmosdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/ugorji/go/codec"
|
||||
|
||||
pkg "github.com/Azure/ARO-RP/pkg/api"
|
||||
)
|
||||
|
||||
type fakePortalDocumentTriggerHandler func(context.Context, *pkg.PortalDocument) error
|
||||
type fakePortalDocumentQueryHandler func(PortalDocumentClient, *Query, *Options) PortalDocumentRawIterator
|
||||
|
||||
var _ PortalDocumentClient = &FakePortalDocumentClient{}
|
||||
|
||||
// NewFakePortalDocumentClient returns a FakePortalDocumentClient
|
||||
func NewFakePortalDocumentClient(h *codec.JsonHandle) *FakePortalDocumentClient {
|
||||
return &FakePortalDocumentClient{
|
||||
portalDocuments: make(map[string][]byte),
|
||||
triggerHandlers: make(map[string]fakePortalDocumentTriggerHandler),
|
||||
queryHandlers: make(map[string]fakePortalDocumentQueryHandler),
|
||||
jsonHandle: h,
|
||||
lock: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// FakePortalDocumentClient is a FakePortalDocumentClient
|
||||
type FakePortalDocumentClient struct {
|
||||
portalDocuments map[string][]byte
|
||||
jsonHandle *codec.JsonHandle
|
||||
lock *sync.RWMutex
|
||||
triggerHandlers map[string]fakePortalDocumentTriggerHandler
|
||||
queryHandlers map[string]fakePortalDocumentQueryHandler
|
||||
sorter func([]*pkg.PortalDocument)
|
||||
|
||||
// returns true if documents conflict
|
||||
conflictChecker func(*pkg.PortalDocument, *pkg.PortalDocument) bool
|
||||
|
||||
// err, if not nil, is an error to return when attempting to communicate
|
||||
// with this Client
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *FakePortalDocumentClient) decodePortalDocument(s []byte) (portalDocument *pkg.PortalDocument, err error) {
|
||||
err = codec.NewDecoderBytes(s, c.jsonHandle).Decode(&portalDocument)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *FakePortalDocumentClient) encodePortalDocument(portalDocument *pkg.PortalDocument) (b []byte, err error) {
|
||||
err = codec.NewEncoderBytes(&b, c.jsonHandle).Encode(portalDocument)
|
||||
return
|
||||
}
|
||||
|
||||
// SetError sets or unsets an error that will be returned on any
|
||||
// FakePortalDocumentClient method invocation
|
||||
func (c *FakePortalDocumentClient) SetError(err error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.err = err
|
||||
}
|
||||
|
||||
// SetSorter sets or unsets a sorter function which will be used to sort values
|
||||
// returned by List() for test stability
|
||||
func (c *FakePortalDocumentClient) SetSorter(sorter func([]*pkg.PortalDocument)) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.sorter = sorter
|
||||
}
|
||||
|
||||
// SetConflictChecker sets or unsets a function which can be used to validate
|
||||
// additional unique keys in a PortalDocument
|
||||
func (c *FakePortalDocumentClient) SetConflictChecker(conflictChecker func(*pkg.PortalDocument, *pkg.PortalDocument) bool) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.conflictChecker = conflictChecker
|
||||
}
|
||||
|
||||
// SetTriggerHandler sets or unsets a trigger handler
|
||||
func (c *FakePortalDocumentClient) SetTriggerHandler(triggerName string, trigger fakePortalDocumentTriggerHandler) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.triggerHandlers[triggerName] = trigger
|
||||
}
|
||||
|
||||
// SetQueryHandler sets or unsets a query handler
|
||||
func (c *FakePortalDocumentClient) SetQueryHandler(queryName string, query fakePortalDocumentQueryHandler) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.queryHandlers[queryName] = query
|
||||
}
|
||||
|
||||
func (c *FakePortalDocumentClient) deepCopy(portalDocument *pkg.PortalDocument) (*pkg.PortalDocument, error) {
|
||||
b, err := c.encodePortalDocument(portalDocument)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.decodePortalDocument(b)
|
||||
}
|
||||
|
||||
func (c *FakePortalDocumentClient) apply(ctx context.Context, partitionkey string, portalDocument *pkg.PortalDocument, options *Options, isCreate bool) (*pkg.PortalDocument, error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if c.err != nil {
|
||||
return nil, c.err
|
||||
}
|
||||
|
||||
portalDocument, err := c.deepCopy(portalDocument) // copy now because pretriggers can mutate portalDocument
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if options != nil {
|
||||
err := c.processPreTriggers(ctx, portalDocument, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
_, exists := c.portalDocuments[portalDocument.ID]
|
||||
if isCreate && exists {
|
||||
return nil, &Error{
|
||||
StatusCode: http.StatusConflict,
|
||||
Message: "Entity with the specified id already exists in the system",
|
||||
}
|
||||
}
|
||||
if !isCreate && !exists {
|
||||
return nil, &Error{StatusCode: http.StatusNotFound}
|
||||
}
|
||||
|
||||
if c.conflictChecker != nil {
|
||||
for id := range c.portalDocuments {
|
||||
portalDocumentToCheck, err := c.decodePortalDocument(c.portalDocuments[id])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.conflictChecker(portalDocumentToCheck, portalDocument) {
|
||||
return nil, &Error{
|
||||
StatusCode: http.StatusConflict,
|
||||
Message: "Entity with the specified id already exists in the system",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b, err := c.encodePortalDocument(portalDocument)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.portalDocuments[portalDocument.ID] = b
|
||||
|
||||
return portalDocument, nil
|
||||
}
|
||||
|
||||
// Create creates a PortalDocument in the database
|
||||
func (c *FakePortalDocumentClient) Create(ctx context.Context, partitionkey string, portalDocument *pkg.PortalDocument, options *Options) (*pkg.PortalDocument, error) {
|
||||
return c.apply(ctx, partitionkey, portalDocument, options, true)
|
||||
}
|
||||
|
||||
// Replace replaces a PortalDocument in the database
|
||||
func (c *FakePortalDocumentClient) Replace(ctx context.Context, partitionkey string, portalDocument *pkg.PortalDocument, options *Options) (*pkg.PortalDocument, error) {
|
||||
return c.apply(ctx, partitionkey, portalDocument, options, false)
|
||||
}
|
||||
|
||||
// List returns a PortalDocumentIterator to list all PortalDocuments in the database
|
||||
func (c *FakePortalDocumentClient) List(*Options) PortalDocumentIterator {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if c.err != nil {
|
||||
return NewFakePortalDocumentErroringRawIterator(c.err)
|
||||
}
|
||||
|
||||
portalDocuments := make([]*pkg.PortalDocument, 0, len(c.portalDocuments))
|
||||
for _, d := range c.portalDocuments {
|
||||
r, err := c.decodePortalDocument(d)
|
||||
if err != nil {
|
||||
return NewFakePortalDocumentErroringRawIterator(err)
|
||||
}
|
||||
portalDocuments = append(portalDocuments, r)
|
||||
}
|
||||
|
||||
if c.sorter != nil {
|
||||
c.sorter(portalDocuments)
|
||||
}
|
||||
|
||||
return NewFakePortalDocumentIterator(portalDocuments, 0)
|
||||
}
|
||||
|
||||
// ListAll lists all PortalDocuments in the database
|
||||
func (c *FakePortalDocumentClient) ListAll(ctx context.Context, options *Options) (*pkg.PortalDocuments, error) {
|
||||
iter := c.List(options)
|
||||
return iter.Next(ctx, -1)
|
||||
}
|
||||
|
||||
// Get gets a PortalDocument from the database
|
||||
func (c *FakePortalDocumentClient) Get(ctx context.Context, partitionkey string, id string, options *Options) (*pkg.PortalDocument, error) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if c.err != nil {
|
||||
return nil, c.err
|
||||
}
|
||||
|
||||
portalDocument, exists := c.portalDocuments[id]
|
||||
if !exists {
|
||||
return nil, &Error{StatusCode: http.StatusNotFound}
|
||||
}
|
||||
|
||||
return c.decodePortalDocument(portalDocument)
|
||||
}
|
||||
|
||||
// Delete deletes a PortalDocument from the database
|
||||
func (c *FakePortalDocumentClient) Delete(ctx context.Context, partitionKey string, portalDocument *pkg.PortalDocument, options *Options) error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if c.err != nil {
|
||||
return c.err
|
||||
}
|
||||
|
||||
_, exists := c.portalDocuments[portalDocument.ID]
|
||||
if !exists {
|
||||
return &Error{StatusCode: http.StatusNotFound}
|
||||
}
|
||||
|
||||
delete(c.portalDocuments, portalDocument.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeFeed is unimplemented
|
||||
func (c *FakePortalDocumentClient) ChangeFeed(*Options) PortalDocumentIterator {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if c.err != nil {
|
||||
return NewFakePortalDocumentErroringRawIterator(c.err)
|
||||
}
|
||||
|
||||
return NewFakePortalDocumentErroringRawIterator(ErrNotImplemented)
|
||||
}
|
||||
|
||||
func (c *FakePortalDocumentClient) processPreTriggers(ctx context.Context, portalDocument *pkg.PortalDocument, options *Options) error {
|
||||
for _, triggerName := range options.PreTriggers {
|
||||
if triggerHandler := c.triggerHandlers[triggerName]; triggerHandler != nil {
|
||||
err := triggerHandler(ctx, portalDocument)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query calls a query handler to implement database querying
|
||||
func (c *FakePortalDocumentClient) Query(name string, query *Query, options *Options) PortalDocumentRawIterator {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if c.err != nil {
|
||||
return NewFakePortalDocumentErroringRawIterator(c.err)
|
||||
}
|
||||
|
||||
if queryHandler := c.queryHandlers[query.Query]; queryHandler != nil {
|
||||
return queryHandler(c, query, options)
|
||||
}
|
||||
|
||||
return NewFakePortalDocumentErroringRawIterator(ErrNotImplemented)
|
||||
}
|
||||
|
||||
// QueryAll calls a query handler to implement database querying
|
||||
func (c *FakePortalDocumentClient) QueryAll(ctx context.Context, partitionkey string, query *Query, options *Options) (*pkg.PortalDocuments, error) {
|
||||
iter := c.Query("", query, options)
|
||||
return iter.Next(ctx, -1)
|
||||
}
|
||||
|
||||
func NewFakePortalDocumentIterator(portalDocuments []*pkg.PortalDocument, continuation int) PortalDocumentRawIterator {
|
||||
return &fakePortalDocumentIterator{portalDocuments: portalDocuments, continuation: continuation}
|
||||
}
|
||||
|
||||
type fakePortalDocumentIterator struct {
|
||||
portalDocuments []*pkg.PortalDocument
|
||||
continuation int
|
||||
done bool
|
||||
}
|
||||
|
||||
func (i *fakePortalDocumentIterator) NextRaw(ctx context.Context, maxItemCount int, out interface{}) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (i *fakePortalDocumentIterator) Next(ctx context.Context, maxItemCount int) (*pkg.PortalDocuments, error) {
|
||||
if i.done {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var portalDocuments []*pkg.PortalDocument
|
||||
if maxItemCount == -1 {
|
||||
portalDocuments = i.portalDocuments[i.continuation:]
|
||||
i.continuation = len(i.portalDocuments)
|
||||
i.done = true
|
||||
} else {
|
||||
max := i.continuation + maxItemCount
|
||||
if max > len(i.portalDocuments) {
|
||||
max = len(i.portalDocuments)
|
||||
}
|
||||
portalDocuments = i.portalDocuments[i.continuation:max]
|
||||
i.continuation += max
|
||||
i.done = i.Continuation() == ""
|
||||
}
|
||||
|
||||
return &pkg.PortalDocuments{
|
||||
PortalDocuments: portalDocuments,
|
||||
Count: len(portalDocuments),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (i *fakePortalDocumentIterator) Continuation() string {
|
||||
if i.continuation >= len(i.portalDocuments) {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%d", i.continuation)
|
||||
}
|
||||
|
||||
// NewFakePortalDocumentErroringRawIterator returns a PortalDocumentRawIterator which
|
||||
// whose methods return the given error
|
||||
func NewFakePortalDocumentErroringRawIterator(err error) PortalDocumentRawIterator {
|
||||
return &fakePortalDocumentErroringRawIterator{err: err}
|
||||
}
|
||||
|
||||
type fakePortalDocumentErroringRawIterator struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (i *fakePortalDocumentErroringRawIterator) Next(ctx context.Context, maxItemCount int) (*pkg.PortalDocuments, error) {
|
||||
return nil, i.err
|
||||
}
|
||||
|
||||
func (i *fakePortalDocumentErroringRawIterator) NextRaw(context.Context, int, interface{}) error {
|
||||
return i.err
|
||||
}
|
||||
|
||||
func (i *fakePortalDocumentErroringRawIterator) Continuation() string {
|
||||
return ""
|
||||
}
|
|
@ -30,6 +30,7 @@ const (
|
|||
collBilling = "Billing"
|
||||
collMonitors = "Monitors"
|
||||
collOpenShiftClusters = "OpenShiftClusters"
|
||||
collPortal = "Portal"
|
||||
collSubscriptions = "Subscriptions"
|
||||
)
|
||||
|
||||
|
|
|
@ -43,6 +43,7 @@ type OpenShiftClusters interface {
|
|||
Delete(context.Context, *api.OpenShiftClusterDocument) error
|
||||
ChangeFeed() cosmosdb.OpenShiftClusterDocumentIterator
|
||||
List(string) cosmosdb.OpenShiftClusterDocumentIterator
|
||||
ListAll(context.Context) (*api.OpenShiftClusterDocuments, error)
|
||||
ListByPrefix(string, string, string) (cosmosdb.OpenShiftClusterDocumentIterator, error)
|
||||
Dequeue(context.Context) (*api.OpenShiftClusterDocument, error)
|
||||
Lease(context.Context, string) (*api.OpenShiftClusterDocument, error)
|
||||
|
@ -246,6 +247,10 @@ func (c *openShiftClusters) List(continuation string) cosmosdb.OpenShiftClusterD
|
|||
return c.c.List(&cosmosdb.Options{Continuation: continuation})
|
||||
}
|
||||
|
||||
func (c *openShiftClusters) ListAll(ctx context.Context) (*api.OpenShiftClusterDocuments, error) {
|
||||
return c.c.ListAll(ctx, nil)
|
||||
}
|
||||
|
||||
func (c *openShiftClusters) ListByPrefix(subscriptionID, prefix, continuation string) (cosmosdb.OpenShiftClusterDocumentIterator, error) {
|
||||
if prefix != strings.ToLower(prefix) {
|
||||
return nil, fmt.Errorf("prefix %q is not lower case", prefix)
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
package database
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
"github.com/Azure/ARO-RP/pkg/database/cosmosdb"
|
||||
"github.com/Azure/ARO-RP/pkg/util/deployment"
|
||||
)
|
||||
|
||||
type portals struct {
|
||||
c cosmosdb.PortalDocumentClient
|
||||
}
|
||||
|
||||
// Portal is the database interface for PortalDocuments
|
||||
type Portal interface {
|
||||
Create(context.Context, *api.PortalDocument) (*api.PortalDocument, error)
|
||||
Get(context.Context, string) (*api.PortalDocument, error)
|
||||
Delete(context.Context, *api.PortalDocument) error
|
||||
}
|
||||
|
||||
// NewPortal returns a new Portal
|
||||
func NewPortal(ctx context.Context, deploymentMode deployment.Mode, dbc cosmosdb.DatabaseClient) (Portal, error) {
|
||||
dbid, err := databaseName(deploymentMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
collc := cosmosdb.NewCollectionClient(dbc, dbid)
|
||||
|
||||
documentClient := cosmosdb.NewPortalDocumentClient(collc, collPortal)
|
||||
return NewPortalWithProvidedClient(documentClient), nil
|
||||
}
|
||||
|
||||
func NewPortalWithProvidedClient(client cosmosdb.PortalDocumentClient) Portal {
|
||||
return &portals{
|
||||
c: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *portals) Create(ctx context.Context, doc *api.PortalDocument) (*api.PortalDocument, error) {
|
||||
if doc.ID != strings.ToLower(doc.ID) {
|
||||
return nil, fmt.Errorf("id %q is not lower case", doc.ID)
|
||||
}
|
||||
|
||||
doc, err := c.c.Create(ctx, doc.ID, doc, nil)
|
||||
|
||||
if err, ok := err.(*cosmosdb.Error); ok && err.StatusCode == http.StatusConflict {
|
||||
err.StatusCode = http.StatusPreconditionFailed
|
||||
}
|
||||
|
||||
return doc, err
|
||||
}
|
||||
|
||||
func (c *portals) Get(ctx context.Context, id string) (*api.PortalDocument, error) {
|
||||
if id != strings.ToLower(id) {
|
||||
return nil, fmt.Errorf("id %q is not lower case", id)
|
||||
}
|
||||
|
||||
return c.c.Get(ctx, id, id, nil)
|
||||
}
|
||||
|
||||
func (c *portals) Delete(ctx context.Context, doc *api.PortalDocument) error {
|
||||
if doc.ID != strings.ToLower(doc.ID) {
|
||||
return fmt.Errorf("id %q is not lower case", doc.ID)
|
||||
}
|
||||
|
||||
return c.c.Delete(ctx, doc.ID, doc, nil)
|
||||
}
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -7,6 +7,7 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/ghodss/yaml"
|
||||
)
|
||||
|
@ -36,6 +37,7 @@ type Configuration struct {
|
|||
DatabaseAccountName *string `json:"databaseAccountName,omitempty" value:"required"`
|
||||
ExtraClusterKeyvaultAccessPolicies []interface{} `json:"extraClusterKeyvaultAccessPolicies,omitempty" value:"required"`
|
||||
ExtraCosmosDBIPs []string `json:"extraCosmosDBIPs,omitempty" value:"required"`
|
||||
ExtraPortalKeyvaultAccessPolicies []interface{} `json:"extraPortalKeyvaultAccessPolicies,omitempty" value:"required"`
|
||||
ExtraServiceKeyvaultAccessPolicies []interface{} `json:"extraServiceKeyvaultAccessPolicies,omitempty" value:"required"`
|
||||
FPServerCertCommonName *string `json:"fpServerCertCommonName,omitempty"`
|
||||
FPServicePrincipalID *string `json:"fpServicePrincipalId,omitempty" value:"required"`
|
||||
|
@ -46,6 +48,9 @@ type Configuration struct {
|
|||
MDMFrontendURL *string `json:"mdmFrontendUrl,omitempty" value:"required"`
|
||||
MDSDConfigVersion *string `json:"mdsdConfigVersion,omitempty" value:"required"`
|
||||
MDSDEnvironment *string `json:"mdsdEnvironment,omitempty" value:"required"`
|
||||
PortalAccessGroupIDs *string `json:"portalAccessGroupIds,omitempty" value:"required"`
|
||||
PortalClientID *string `json:"portalClientId,omitempty" value:"required"`
|
||||
PortalElevatedGroupIDs *string `json:"portalElevatedGroupIds,omitempty" value:"required"`
|
||||
RPImagePrefix *string `json:"rpImagePrefix,omitempty" value:"required"`
|
||||
RPMode *string `json:"rpMode,omitempty"`
|
||||
RPNSGSourceAddressPrefixes []string `json:"rpNsgSourceAddressPrefixes,omitempty" value:"required"`
|
||||
|
@ -119,5 +124,5 @@ func (conf *RPConfig) validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("Configuration has missing fields: %s", missingFields)
|
||||
return fmt.Errorf("configuration has missing fields: %s", strings.Join(missingFields, ","))
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ package deploy
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -114,87 +113,25 @@ func TestConfigNilable(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestConfigRequiredValues(t *testing.T) {
|
||||
AdminAPICABundle := "AdminAPICABundle"
|
||||
ACRResourceID := "ACRResourceID"
|
||||
ExtraCosmosDBIPs := "ExtraCosmosDBIPs"
|
||||
MDMFrontendURL := "MDMFrontendURL"
|
||||
AdminAPIClientCertCommonName := "AdminAPIClientCertCommonName"
|
||||
ClusterParentDomainName := "ClusterParentDomainName"
|
||||
DatabaseAccountName := "DatabaseAccountName"
|
||||
FPServerCertCommonName := "FPServerCertCommonName"
|
||||
FPServicePrincipalID := "FPServicePrincipalID"
|
||||
GlobalResourceGroupName := "GlobalResourceGroupName"
|
||||
GlobalResourceGroupLocation := "GlobalResourceGroupLocation"
|
||||
GlobalSubscriptionID := "GlobalSubscriptionID"
|
||||
KeyvaultPrefix := "KeyvaultPrefix"
|
||||
MDSDConfigVersion := "MDSDConfigVersion"
|
||||
MDSDEnvironment := "MDSDEnvironment"
|
||||
RPImagePrefix := "RPImagePrefix"
|
||||
RPMode := "RPMode"
|
||||
RPParentDomainName := "RPParentDomainName"
|
||||
RPVersionStorageAccountName := "RPVersionStorageAccountName"
|
||||
SSHPublicKey := "SSHPublicKey"
|
||||
SubscriptionResourceGroupName := "SubscriptionResourceGroupName"
|
||||
SubscriptionResourceGroupLocation := "SubscriptionResourceGroupLocation"
|
||||
StorageAccountDomain := "StorageAccountDomain"
|
||||
VMSize := "VMSize"
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
config RPConfig
|
||||
expect error
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: RPConfig{
|
||||
Configuration: &Configuration{
|
||||
ACRResourceID: &ACRResourceID,
|
||||
AdminAPICABundle: &AdminAPICABundle,
|
||||
ExtraCosmosDBIPs: []string{ExtraCosmosDBIPs},
|
||||
MDMFrontendURL: &MDMFrontendURL,
|
||||
ACRReplicaDisabled: to.BoolPtr(true),
|
||||
AdminAPIClientCertCommonName: &AdminAPIClientCertCommonName,
|
||||
ClusterParentDomainName: &ClusterParentDomainName,
|
||||
DatabaseAccountName: &DatabaseAccountName,
|
||||
ExtraClusterKeyvaultAccessPolicies: []interface{}{},
|
||||
ExtraServiceKeyvaultAccessPolicies: []interface{}{},
|
||||
FPServerCertCommonName: &FPServerCertCommonName,
|
||||
FPServicePrincipalID: &FPServicePrincipalID,
|
||||
GlobalResourceGroupName: &GlobalResourceGroupName,
|
||||
GlobalResourceGroupLocation: &GlobalResourceGroupLocation,
|
||||
GlobalSubscriptionID: &GlobalSubscriptionID,
|
||||
KeyvaultPrefix: &KeyvaultPrefix,
|
||||
MDSDConfigVersion: &MDSDConfigVersion,
|
||||
MDSDEnvironment: &MDSDEnvironment,
|
||||
RPImagePrefix: &RPImagePrefix,
|
||||
RPMode: &RPMode,
|
||||
RPNSGSourceAddressPrefixes: []string{},
|
||||
RPParentDomainName: &RPParentDomainName,
|
||||
RPVersionStorageAccountName: &RPVersionStorageAccountName,
|
||||
SSHPublicKey: &SSHPublicKey,
|
||||
SubscriptionResourceGroupName: &SubscriptionResourceGroupName,
|
||||
SubscriptionResourceGroupLocation: &SubscriptionResourceGroupLocation,
|
||||
StorageAccountDomain: &StorageAccountDomain,
|
||||
VMSize: &VMSize,
|
||||
},
|
||||
},
|
||||
expect: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid config",
|
||||
config: RPConfig{
|
||||
Configuration: &Configuration{
|
||||
ACRResourceID: &ACRResourceID,
|
||||
AdminAPICABundle: &AdminAPICABundle,
|
||||
ExtraCosmosDBIPs: []string{ExtraCosmosDBIPs},
|
||||
Configuration: &Configuration{},
|
||||
},
|
||||
},
|
||||
expect: fmt.Errorf("Configuration has missing fields: %s", "[RPVersionStorageAccountName AdminAPIClientCertCommonName ClusterParentDomainName DatabaseAccountName ExtraClusterKeyvaultAccessPolicies ExtraServiceKeyvaultAccessPolicies FPServicePrincipalID GlobalResourceGroupName GlobalResourceGroupLocation GlobalSubscriptionID KeyvaultPrefix MDMFrontendURL MDSDConfigVersion MDSDEnvironment RPImagePrefix RPNSGSourceAddressPrefixes RPParentDomainName SubscriptionResourceGroupName SubscriptionResourceGroupLocation SSHPublicKey StorageAccountDomain VMSize]"),
|
||||
wantErr: "configuration has missing fields: ACRResourceID,RPVersionStorageAccountName,AdminAPICABundle,AdminAPIClientCertCommonName,ClusterParentDomainName,DatabaseAccountName,ExtraClusterKeyvaultAccessPolicies,ExtraCosmosDBIPs,ExtraPortalKeyvaultAccessPolicies,ExtraServiceKeyvaultAccessPolicies,FPServicePrincipalID,GlobalResourceGroupName,GlobalResourceGroupLocation,GlobalSubscriptionID,KeyvaultPrefix,MDMFrontendURL,MDSDConfigVersion,MDSDEnvironment,PortalAccessGroupIDs,PortalClientID,PortalElevatedGroupIDs,RPImagePrefix,RPNSGSourceAddressPrefixes,RPParentDomainName,SubscriptionResourceGroupName,SubscriptionResourceGroupLocation,SSHPublicKey,StorageAccountDomain,VMSize",
|
||||
},
|
||||
} {
|
||||
valid := tt.config.validate()
|
||||
if valid != tt.expect && valid.Error() != tt.expect.Error() {
|
||||
t.Errorf("Expected %s but got %s", tt.name, valid.Error())
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.validate()
|
||||
if err != nil && err.Error() != tt.wantErr ||
|
||||
err == nil && tt.wantErr != "" {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -180,6 +180,11 @@ func (d *deployer) configureDNS(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
portalPip, err := d.publicipaddresses.Get(ctx, d.config.ResourceGroupName, "portal-pip", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
zone, err := d.zones.Get(ctx, d.config.ResourceGroupName, d.config.Location+"."+*d.config.Configuration.ClusterParentDomainName)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -199,6 +204,20 @@ func (d *deployer) configureDNS(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
_, err = d.globalrecordsets.CreateOrUpdate(ctx, *d.config.Configuration.GlobalResourceGroupName, *d.config.Configuration.RPParentDomainName, "admin."+d.config.Location, mgmtdns.A, mgmtdns.RecordSet{
|
||||
RecordSetProperties: &mgmtdns.RecordSetProperties{
|
||||
TTL: to.Int64Ptr(3600),
|
||||
ARecords: &[]mgmtdns.ARecord{
|
||||
{
|
||||
Ipv4Address: portalPip.IPAddress,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, "", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nsRecords := make([]mgmtdns.NsRecord, 0, len(*zone.NameServers))
|
||||
for i := range *zone.NameServers {
|
||||
nsRecords = append(nsRecords, mgmtdns.NsRecord{
|
||||
|
|
|
@ -6,6 +6,7 @@ package generator
|
|||
// Deployment constants
|
||||
const (
|
||||
ClustersKeyvaultSuffix = "-cls"
|
||||
PortalKeyvaultSuffix = "-por"
|
||||
ServiceKeyvaultSuffix = "-svc"
|
||||
)
|
||||
|
||||
|
|
|
@ -507,7 +507,7 @@ func (g *generator) pevnet() *arm.Resource {
|
|||
}
|
||||
}
|
||||
|
||||
func (g *generator) pip() *arm.Resource {
|
||||
func (g *generator) pip(name string) *arm.Resource {
|
||||
return &arm.Resource{
|
||||
Resource: &mgmtnetwork.PublicIPAddress{
|
||||
Sku: &mgmtnetwork.PublicIPAddressSku{
|
||||
|
@ -516,7 +516,7 @@ func (g *generator) pip() *arm.Resource {
|
|||
PublicIPAddressPropertiesFormat: &mgmtnetwork.PublicIPAddressPropertiesFormat{
|
||||
PublicIPAllocationMethod: mgmtnetwork.Static,
|
||||
},
|
||||
Name: to.StringPtr("rp-pip"),
|
||||
Name: &name,
|
||||
Type: to.StringPtr("Microsoft.Network/publicIPAddresses"),
|
||||
Location: to.StringPtr("[resourceGroup().location]"),
|
||||
},
|
||||
|
@ -541,6 +541,14 @@ func (g *generator) lb() *arm.Resource {
|
|||
},
|
||||
Name: to.StringPtr("rp-frontend"),
|
||||
},
|
||||
{
|
||||
FrontendIPConfigurationPropertiesFormat: &mgmtnetwork.FrontendIPConfigurationPropertiesFormat{
|
||||
PublicIPAddress: &mgmtnetwork.PublicIPAddress{
|
||||
ID: to.StringPtr("[resourceId('Microsoft.Network/publicIPAddresses', 'portal-pip')]"),
|
||||
},
|
||||
},
|
||||
Name: to.StringPtr("portal-frontend"),
|
||||
},
|
||||
},
|
||||
BackendAddressPools: &[]mgmtnetwork.BackendAddressPool{
|
||||
{
|
||||
|
@ -566,6 +574,42 @@ func (g *generator) lb() *arm.Resource {
|
|||
},
|
||||
Name: to.StringPtr("rp-lbrule"),
|
||||
},
|
||||
{
|
||||
LoadBalancingRulePropertiesFormat: &mgmtnetwork.LoadBalancingRulePropertiesFormat{
|
||||
FrontendIPConfiguration: &mgmtnetwork.SubResource{
|
||||
ID: to.StringPtr("[resourceId('Microsoft.Network/loadBalancers/frontendIPConfigurations', 'rp-lb', 'portal-frontend')]"),
|
||||
},
|
||||
BackendAddressPool: &mgmtnetwork.SubResource{
|
||||
ID: to.StringPtr("[resourceId('Microsoft.Network/loadBalancers/backendAddressPools', 'rp-lb', 'rp-backend')]"),
|
||||
},
|
||||
Probe: &mgmtnetwork.SubResource{
|
||||
ID: to.StringPtr("[resourceId('Microsoft.Network/loadBalancers/probes', 'rp-lb', 'portal-probe-https')]"),
|
||||
},
|
||||
Protocol: mgmtnetwork.TransportProtocolTCP,
|
||||
LoadDistribution: mgmtnetwork.LoadDistributionDefault,
|
||||
FrontendPort: to.Int32Ptr(443),
|
||||
BackendPort: to.Int32Ptr(444),
|
||||
},
|
||||
Name: to.StringPtr("portal-lbrule"),
|
||||
},
|
||||
{
|
||||
LoadBalancingRulePropertiesFormat: &mgmtnetwork.LoadBalancingRulePropertiesFormat{
|
||||
FrontendIPConfiguration: &mgmtnetwork.SubResource{
|
||||
ID: to.StringPtr("[resourceId('Microsoft.Network/loadBalancers/frontendIPConfigurations', 'rp-lb', 'portal-frontend')]"),
|
||||
},
|
||||
BackendAddressPool: &mgmtnetwork.SubResource{
|
||||
ID: to.StringPtr("[resourceId('Microsoft.Network/loadBalancers/backendAddressPools', 'rp-lb', 'rp-backend')]"),
|
||||
},
|
||||
Probe: &mgmtnetwork.SubResource{
|
||||
ID: to.StringPtr("[resourceId('Microsoft.Network/loadBalancers/probes', 'rp-lb', 'portal-probe-ssh')]"),
|
||||
},
|
||||
Protocol: mgmtnetwork.TransportProtocolTCP,
|
||||
LoadDistribution: mgmtnetwork.LoadDistributionDefault,
|
||||
FrontendPort: to.Int32Ptr(22),
|
||||
BackendPort: to.Int32Ptr(2222),
|
||||
},
|
||||
Name: to.StringPtr("portal-lbrule"),
|
||||
},
|
||||
},
|
||||
Probes: &[]mgmtnetwork.Probe{
|
||||
{
|
||||
|
@ -577,6 +621,23 @@ func (g *generator) lb() *arm.Resource {
|
|||
},
|
||||
Name: to.StringPtr("rp-probe"),
|
||||
},
|
||||
{
|
||||
ProbePropertiesFormat: &mgmtnetwork.ProbePropertiesFormat{
|
||||
Protocol: mgmtnetwork.ProbeProtocolHTTPS,
|
||||
Port: to.Int32Ptr(444),
|
||||
NumberOfProbes: to.Int32Ptr(2),
|
||||
RequestPath: to.StringPtr("/healthz/ready"),
|
||||
},
|
||||
Name: to.StringPtr("portal-probe-https"),
|
||||
},
|
||||
{
|
||||
ProbePropertiesFormat: &mgmtnetwork.ProbePropertiesFormat{
|
||||
Protocol: mgmtnetwork.ProbeProtocolTCP,
|
||||
Port: to.Int32Ptr(2222),
|
||||
NumberOfProbes: to.Int32Ptr(2),
|
||||
},
|
||||
Name: to.StringPtr("portal-probe-ssh"),
|
||||
},
|
||||
},
|
||||
},
|
||||
Name: to.StringPtr("rp-lb"),
|
||||
|
@ -664,6 +725,9 @@ func (g *generator) vmss() *arm.Resource {
|
|||
"mdsdEnvironment",
|
||||
"acrResourceId",
|
||||
"domainName",
|
||||
"portalAccessGroupIds",
|
||||
"portalClientId",
|
||||
"portalElevatedGroupIds",
|
||||
"rpImage",
|
||||
"rpMode",
|
||||
"adminApiClientCertCommonName",
|
||||
|
@ -764,6 +828,8 @@ EOF
|
|||
sysctl --system
|
||||
|
||||
firewall-cmd --add-port=443/tcp --permanent
|
||||
firewall-cmd --add-port=444/tcp --permanent
|
||||
firewall-cmd --add-port=2222/tcp --permanent
|
||||
|
||||
cat >/etc/td-agent-bit/td-agent-bit.conf <<'EOF'
|
||||
[INPUT]
|
||||
|
@ -970,9 +1036,49 @@ StartLimitInterval=0
|
|||
WantedBy=multi-user.target
|
||||
EOF
|
||||
|
||||
cat >/etc/sysconfig/aro-portal <<EOF
|
||||
MDM_ACCOUNT=AzureRedHatOpenShiftRP
|
||||
MDM_NAMESPACE=Portal
|
||||
AZURE_PORTAL_CLIENT_ID='$PORTALCLIENTID'
|
||||
AZURE_PORTAL_ACCESS_GROUP_IDS='$PORTALACCESSGROUPIDS'
|
||||
AZURE_PORTAL_ELEVATED_GROUP_IDS='$PORTALELEVATEDGROUPIDS'
|
||||
RPIMAGE='$RPIMAGE'
|
||||
RP_MODE='$RPMODE'
|
||||
EOF
|
||||
|
||||
cat >/etc/systemd/system/aro-portal.service <<'EOF'
|
||||
[Unit]
|
||||
After=docker.service
|
||||
Requires=docker.service
|
||||
StartLimitInterval=0
|
||||
|
||||
[Service]
|
||||
EnvironmentFile=/etc/sysconfig/aro-portal
|
||||
ExecStartPre=-/usr/bin/docker rm -f %N
|
||||
ExecStart=/usr/bin/docker run \
|
||||
--hostname %H \
|
||||
--name %N \
|
||||
--rm \
|
||||
-e ADMIN_API_CLIENT_CERT_COMMON_NAME \
|
||||
-e MDM_ACCOUNT \
|
||||
-e MDM_NAMESPACE \
|
||||
-e RP_MODE \
|
||||
-p 444:8444 \
|
||||
-p 2222:2222 \
|
||||
-v /run/systemd/journal:/run/systemd/journal \
|
||||
-v /var/etw:/var/etw:z \
|
||||
$RPIMAGE \
|
||||
portal
|
||||
Restart=always
|
||||
RestartSec=1
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
|
||||
chcon -R system_u:object_r:var_log_t:s0 /var/opt/microsoft/linuxmonagent
|
||||
|
||||
for service in aro-monitor aro-rp auoms azsecd azsecmond mdsd mdm chronyd td-agent-bit; do
|
||||
for service in aro-monitor aro-portal aro-rp auoms azsecd azsecmond mdsd mdm chronyd td-agent-bit; do
|
||||
systemctl enable $service.service
|
||||
done
|
||||
|
||||
|
@ -1155,6 +1261,20 @@ func (g *generator) clusterKeyvaultAccessPolicies() []mgmtkeyvault.AccessPolicyE
|
|||
}
|
||||
}
|
||||
|
||||
func (g *generator) portalKeyvaultAccessPolicies() []mgmtkeyvault.AccessPolicyEntry {
|
||||
return []mgmtkeyvault.AccessPolicyEntry{
|
||||
{
|
||||
TenantID: &tenantUUIDHack,
|
||||
ObjectID: to.StringPtr("[parameters('rpServicePrincipalId')]"),
|
||||
Permissions: &mgmtkeyvault.Permissions{
|
||||
Secrets: &[]mgmtkeyvault.SecretPermissions{
|
||||
mgmtkeyvault.SecretPermissionsGet,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (g *generator) serviceKeyvaultAccessPolicies() []mgmtkeyvault.AccessPolicyEntry {
|
||||
return []mgmtkeyvault.AccessPolicyEntry{
|
||||
{
|
||||
|
@ -1207,6 +1327,50 @@ func (g *generator) clustersKeyvault() *arm.Resource {
|
|||
}
|
||||
}
|
||||
|
||||
func (g *generator) portalKeyvault() *arm.Resource {
|
||||
vault := &mgmtkeyvault.Vault{
|
||||
Properties: &mgmtkeyvault.VaultProperties{
|
||||
EnableSoftDelete: to.BoolPtr(true),
|
||||
TenantID: &tenantUUIDHack,
|
||||
Sku: &mgmtkeyvault.Sku{
|
||||
Name: mgmtkeyvault.Standard,
|
||||
Family: to.StringPtr("A"),
|
||||
},
|
||||
AccessPolicies: &[]mgmtkeyvault.AccessPolicyEntry{},
|
||||
},
|
||||
Name: to.StringPtr("[concat(parameters('keyvaultPrefix'), '" + PortalKeyvaultSuffix + "')]"),
|
||||
Type: to.StringPtr("Microsoft.KeyVault/vaults"),
|
||||
Location: to.StringPtr("[resourceGroup().location]"),
|
||||
}
|
||||
|
||||
if !g.production {
|
||||
*vault.Properties.AccessPolicies = append(g.portalKeyvaultAccessPolicies(),
|
||||
mgmtkeyvault.AccessPolicyEntry{
|
||||
TenantID: &tenantUUIDHack,
|
||||
ObjectID: to.StringPtr("[parameters('adminObjectId')]"),
|
||||
Permissions: &mgmtkeyvault.Permissions{
|
||||
Certificates: &[]mgmtkeyvault.CertificatePermissions{
|
||||
mgmtkeyvault.Delete,
|
||||
mgmtkeyvault.Get,
|
||||
mgmtkeyvault.Import,
|
||||
mgmtkeyvault.List,
|
||||
},
|
||||
Secrets: &[]mgmtkeyvault.SecretPermissions{
|
||||
mgmtkeyvault.SecretPermissionsSet,
|
||||
mgmtkeyvault.SecretPermissionsList,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
return &arm.Resource{
|
||||
Resource: vault,
|
||||
Condition: g.conditionStanza("fullDeploy"),
|
||||
APIVersion: azureclient.APIVersion("Microsoft.KeyVault"),
|
||||
}
|
||||
}
|
||||
|
||||
func (g *generator) serviceKeyvault() *arm.Resource {
|
||||
vault := &mgmtkeyvault.Vault{
|
||||
Properties: &mgmtkeyvault.VaultProperties{
|
||||
|
@ -1304,6 +1468,36 @@ func (g *generator) cosmosdb() []*arm.Resource {
|
|||
}
|
||||
|
||||
func (g *generator) database(databaseName string, addDependsOn bool) []*arm.Resource {
|
||||
portal := &arm.Resource{
|
||||
Resource: &mgmtdocumentdb.SQLContainerCreateUpdateParameters{
|
||||
SQLContainerCreateUpdateProperties: &mgmtdocumentdb.SQLContainerCreateUpdateProperties{
|
||||
Resource: &mgmtdocumentdb.SQLContainerResource{
|
||||
ID: to.StringPtr("Portal"),
|
||||
PartitionKey: &mgmtdocumentdb.ContainerPartitionKey{
|
||||
Paths: &[]string{
|
||||
"/id",
|
||||
},
|
||||
Kind: mgmtdocumentdb.PartitionKindHash,
|
||||
},
|
||||
DefaultTTL: to.Int32Ptr(-1),
|
||||
},
|
||||
Options: map[string]*string{},
|
||||
},
|
||||
Name: to.StringPtr("[concat(parameters('databaseAccountName'), '/', " + databaseName + ", '/Portal')]"),
|
||||
Type: to.StringPtr("Microsoft.DocumentDB/databaseAccounts/sqlDatabases/containers"),
|
||||
Location: to.StringPtr("[resourceGroup().location]"),
|
||||
},
|
||||
Condition: g.conditionStanza("fullDeploy"),
|
||||
APIVersion: azureclient.APIVersion("Microsoft.DocumentDB"),
|
||||
DependsOn: []string{
|
||||
"[resourceId('Microsoft.DocumentDB/databaseAccounts/sqlDatabases', parameters('databaseAccountName'), " + databaseName + ")]",
|
||||
},
|
||||
}
|
||||
|
||||
if g.production {
|
||||
portal.Resource.(*mgmtdocumentdb.SQLContainerCreateUpdateParameters).SQLContainerCreateUpdateProperties.Options["throughput"] = to.StringPtr("400")
|
||||
}
|
||||
|
||||
rs := []*arm.Resource{
|
||||
{
|
||||
Resource: &mgmtdocumentdb.SQLDatabaseCreateUpdateParameters{
|
||||
|
@ -1439,6 +1633,7 @@ func (g *generator) database(databaseName string, addDependsOn bool) []*arm.Reso
|
|||
"[resourceId('Microsoft.DocumentDB/databaseAccounts/sqlDatabases', parameters('databaseAccountName'), " + databaseName + ")]",
|
||||
},
|
||||
},
|
||||
portal,
|
||||
{
|
||||
Resource: &mgmtdocumentdb.SQLContainerCreateUpdateParameters{
|
||||
SQLContainerCreateUpdateProperties: &mgmtdocumentdb.SQLContainerCreateUpdateProperties{
|
||||
|
|
|
@ -86,6 +86,9 @@ func (g *generator) rpTemplate() *arm.Template {
|
|||
"mdmFrontendUrl",
|
||||
"mdsdConfigVersion",
|
||||
"mdsdEnvironment",
|
||||
"portalAccessGroupIds",
|
||||
"portalClientId",
|
||||
"portalElevatedGroupIds",
|
||||
"rpImage",
|
||||
"rpMode",
|
||||
"sshPublicKey",
|
||||
|
@ -111,7 +114,11 @@ func (g *generator) rpTemplate() *arm.Template {
|
|||
}
|
||||
|
||||
if g.production {
|
||||
t.Resources = append(t.Resources, g.pip(), g.lb(), g.vmss(),
|
||||
t.Resources = append(t.Resources,
|
||||
g.pip("rp-pip"),
|
||||
g.pip("portal-pip"),
|
||||
g.lb(),
|
||||
g.vmss(),
|
||||
g.storageAccount(),
|
||||
g.lbAlert(30.0, 2, "rp-availability-alert", "PT5M", "PT15M", "DipAvailability"), // triggers on all 3 RPs being down for 10min, can't be >=0.3 due to deploys going down to 32% at times.
|
||||
g.lbAlert(67.0, 3, "rp-degraded-alert", "PT15M", "PT6H", "DipAvailability"), // 1/3 backend down for 1h or 2/3 down for 3h in the last 6h
|
||||
|
@ -270,6 +277,7 @@ func (g *generator) preDeployTemplate() *arm.Template {
|
|||
if g.production {
|
||||
t.Variables = map[string]interface{}{
|
||||
"clusterKeyvaultAccessPolicies": g.clusterKeyvaultAccessPolicies(),
|
||||
"portalKeyvaultAccessPolicies": g.portalKeyvaultAccessPolicies(),
|
||||
"serviceKeyvaultAccessPolicies": g.serviceKeyvaultAccessPolicies(),
|
||||
}
|
||||
}
|
||||
|
@ -284,6 +292,7 @@ func (g *generator) preDeployTemplate() *arm.Template {
|
|||
params = append(params,
|
||||
"deployNSGs",
|
||||
"extraClusterKeyvaultAccessPolicies",
|
||||
"extraPortalKeyvaultAccessPolicies",
|
||||
"extraServiceKeyvaultAccessPolicies",
|
||||
"fullDeploy",
|
||||
"rpNsgSourceAddressPrefixes",
|
||||
|
@ -300,7 +309,9 @@ func (g *generator) preDeployTemplate() *arm.Template {
|
|||
case "deployNSGs":
|
||||
p.Type = "bool"
|
||||
p.DefaultValue = false
|
||||
case "extraClusterKeyvaultAccessPolicies", "extraServiceKeyvaultAccessPolicies":
|
||||
case "extraClusterKeyvaultAccessPolicies",
|
||||
"extraPortalKeyvaultAccessPolicies",
|
||||
"extraServiceKeyvaultAccessPolicies":
|
||||
p.Type = "array"
|
||||
p.DefaultValue = []interface{}{}
|
||||
case "fullDeploy":
|
||||
|
@ -318,9 +329,10 @@ func (g *generator) preDeployTemplate() *arm.Template {
|
|||
t.Resources = append(t.Resources,
|
||||
g.securityGroupRP(),
|
||||
g.securityGroupPE(),
|
||||
// clustersKeyvault must preceed serviceKeyvault due to terrible
|
||||
// bytes.Replace in templateFixup
|
||||
// clustersKeyvault, portalKeyvault and serviceKeyvault must be in this
|
||||
// order due to terrible bytes.Replace in templateFixup
|
||||
g.clustersKeyvault(),
|
||||
g.portalKeyvault(),
|
||||
g.serviceKeyvault(),
|
||||
)
|
||||
|
||||
|
@ -409,6 +421,7 @@ func (g *generator) templateFixup(t *arm.Template) ([]byte, error) {
|
|||
b = bytes.ReplaceAll(b, []byte(`"capacity": 1337`), []byte(`"capacity": "[int(parameters('ciCapacity'))]"`))
|
||||
if g.production {
|
||||
b = bytes.Replace(b, []byte(`"accessPolicies": []`), []byte(`"accessPolicies": "[concat(variables('clusterKeyvaultAccessPolicies'), parameters('extraClusterKeyvaultAccessPolicies'))]"`), 1)
|
||||
b = bytes.Replace(b, []byte(`"accessPolicies": []`), []byte(`"accessPolicies": "[concat(variables('portalKeyvaultAccessPolicies'), parameters('extraPortalKeyvaultAccessPolicies'))]"`), 1)
|
||||
b = bytes.Replace(b, []byte(`"accessPolicies": []`), []byte(`"accessPolicies": "[concat(variables('serviceKeyvaultAccessPolicies'), parameters('extraServiceKeyvaultAccessPolicies'))]"`), 1)
|
||||
b = bytes.Replace(b, []byte(`"sourceAddressPrefixes": []`), []byte(`"sourceAddressPrefixes": "[parameters('rpNsgSourceAddressPrefixes')]"`), 1)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,8 @@ package deploy
|
|||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
|
@ -307,7 +309,17 @@ func (d *deployer) configureServiceSecrets(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
return d.ensureSecret(ctx, secrets, env.FrontendEncryptionSecretName)
|
||||
err = d.ensureSecret(ctx, secrets, env.FrontendEncryptionSecretName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = d.ensureSecret(ctx, secrets, env.PortalServerSessionKeySecretName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return d.ensureSecretKey(ctx, secrets, env.PortalServerSSHKeySecretName)
|
||||
}
|
||||
|
||||
func (d *deployer) ensureSecret(ctx context.Context, existingSecrets []keyvault.SecretItem, secretName string) error {
|
||||
|
@ -328,3 +340,21 @@ func (d *deployer) ensureSecret(ctx context.Context, existingSecrets []keyvault.
|
|||
Value: to.StringPtr(base64.StdEncoding.EncodeToString(key)),
|
||||
})
|
||||
}
|
||||
|
||||
func (d *deployer) ensureSecretKey(ctx context.Context, existingSecrets []keyvault.SecretItem, secretName string) error {
|
||||
for _, secret := range existingSecrets {
|
||||
if filepath.Base(*secret.ID) == secretName {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.log.Infof("setting %s", secretName)
|
||||
return d.keyvault.SetSecret(ctx, secretName, keyvault.SecretSetParameters{
|
||||
Value: to.StringPtr(base64.StdEncoding.EncodeToString(x509.MarshalPKCS1PrivateKey(key))),
|
||||
})
|
||||
}
|
||||
|
|
|
@ -26,6 +26,10 @@ const (
|
|||
FrontendEncryptionSecretName = "fe-encryption-key"
|
||||
RPLoggingSecretName = "rp-mdsd"
|
||||
RPMonitoringSecretName = "rp-mdm"
|
||||
PortalServerSecretName = "portal-server"
|
||||
PortalServerClientSecretName = "portal-client"
|
||||
PortalServerSessionKeySecretName = "portal-session-key"
|
||||
PortalServerSSHKeySecretName = "portal-sshkey"
|
||||
)
|
||||
|
||||
type Interface interface {
|
||||
|
|
|
@ -4,5 +4,5 @@ package env
|
|||
// Licensed under the Apache License 2.0.
|
||||
|
||||
//go:generate rm -rf ../util/mocks/$GOPACKAGE
|
||||
//go:generate go run ../../vendor/github.com/golang/mock/mockgen -destination=../util/mocks/$GOPACKAGE/$GOPACKAGE.go github.com/Azure/ARO-RP/pkg/$GOPACKAGE Interface
|
||||
//go:generate go run ../../vendor/github.com/golang/mock/mockgen -destination=../util/mocks/$GOPACKAGE/$GOPACKAGE.go github.com/Azure/ARO-RP/pkg/$GOPACKAGE Core,Interface
|
||||
//go:generate go run ../../vendor/golang.org/x/tools/cmd/goimports -local=github.com/Azure/ARO-RP -e -w ../util/mocks/$GOPACKAGE/$GOPACKAGE.go
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
<!doctype html>
<html lang="en">

<head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">

    <link rel="stylesheet" href="lib/bootstrap-4.5.2.min.css">

    <title>ARO SRE portal ({{ .location }})</title>
</head>

<body>
    <div class="navbar navbar-light bg-light shadow-sm">
        <div class="navbar-brand">
            <strong>ARO SRE portal ({{ .location }})</strong>
        </div>

        <button class="btn btn-secondary" id="btnLogout">Logout</button>
    </div>

    <div class="container py-4">
        <div class="form-group">
            <label for="selResourceId">Cluster:</label>
            <div class="col-sm-10">
                <select class="form-control form-control-sm" id="selResourceId">
                </select>
            </div>
        </div>

        <div class="form-group">
            <label for="selMaster">Master:</label>
            <div class="col-sm-10">
                <select class="form-control form-control-sm" id="selMaster">
                    <option value="0">master-0</option>
                    <option value="1">master-1</option>
                    <option value="2">master-2</option>
                </select>
            </div>
        </div>

        <button class="btn btn-secondary" id="btnPrometheus">Prometheus</button>

        <button class="btn btn-secondary" id="btnKubeconfig">Kubeconfig</button>

        <button class="btn btn-secondary" id="btnSSH">SSH</button>

        <div class="py-4" id="divAlerts"></div>
    </div>

    <!-- alert shown after a successful SSH credential request -->
    <template id="tmplSSHAlert">
        <div class="alert alert-primary alert-dismissible fade show" role="alert">
            <div>
                <button class="btn btn-secondary copy-button">
                    <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="white" width="18px" height="18px">
                        <path d="M0 0h24v24H0z" fill="none"/>
                        <path d="M16 1H4c-1.1 0-2 .9-2 2v14h2V3h12V1zm3 4H8c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h11c1.1 0 2-.9 2-2V7c0-1.1-.9-2-2-2zm0 16H8V7h11v14z"/>
                    </svg>
                </button>
                <span data-copy="command">Command: <code></code></span>
            </div>
            <div class="py-2">
                <button class="btn btn-secondary copy-button">
                    <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="white" width="18px" height="18px">
                        <path d="M0 0h24v24H0z" fill="none"/>
                        <path d="M16 1H4c-1.1 0-2 .9-2 2v14h2V3h12V1zm3 4H8c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h11c1.1 0 2-.9 2-2V7c0-1.1-.9-2-2-2zm0 16H8V7h11v14z"/>
                    </svg>
                </button>
                <span data-copy="password">Password: <code></code></span>
            </div>
            <button type="button" class="close" data-dismiss="alert" aria-label="Close">
                <span aria-hidden="true">&times;</span>
            </button>
        </div>
    </template>

    <!-- alert shown when the SSH credential request fails -->
    <template id="tmplSSHAlertError">
        <div class="alert alert-danger alert-dismissible fade show" role="alert">
            <span data-copy="error"></span>
            <button type="button" class="close" data-dismiss="alert" aria-label="Close">
                <span aria-hidden="true">&times;</span>
            </button>
        </div>
    </template>

    {{ .csrfField }}

    <script src="lib/jquery-3.5.1.min.js"></script>
    <script src="lib/bootstrap-4.5.2.min.js"></script>
    <script src="index.js"></script>
</body>

</html>
|
|
@ -0,0 +1,89 @@
|
|||
$.extend({
    // redirect POSTs to `location` with `args` rendered as hidden form
    // fields, carrying the gorilla CSRF token along so the server accepts
    // the request.
    redirect: function (location, args) {
        var form = $("<form method='POST' style='display: none;'></form>");
        form.attr("action", location);

        $.each(args || {}, function (key, value) {
            // fixed: was "<input name='hidden'></input>" — <input> is a void
            // element and must not have a closing tag; the name is set below
            // anyway.
            var input = $("<input>");

            input.attr("name", key);
            input.attr("value", value);

            form.append(input);
        });

        form.append($("input[name='gorilla.csrf.Token']").first());
        form.appendTo("body").submit();
    }
});

$(document).ready(function () {
    // populate the cluster selector from the portal API
    $.ajax({
        url: "/api/clusters",
        success: function (clusters) {
            $.each(clusters, function (i, cluster) {
                $("#selResourceId").append($("<option>").text(cluster));
            });
        },
        dataType: "json",
    });

    $("#btnLogout").click(function () {
        $.redirect("/api/logout");
    });

    // POST so the response can be delivered as a file download
    $("#btnKubeconfig").click(function () {
        $.redirect($("#selResourceId").val() + "/kubeconfig/new");
    });

    $("#btnPrometheus").click(function () {
        window.location = $("#selResourceId").val() + "/prometheus";
    });

    // request temporary SSH credentials for the selected master and render
    // the result (or error) as a dismissible alert
    $("#btnSSH").click(function () {
        $.ajax({
            method: "POST",
            url: $("#selResourceId").val() + "/ssh/new",
            headers: {
                "X-CSRF-Token": $("input[name='gorilla.csrf.Token']").val(),
            },
            contentType: "application/json",
            data: JSON.stringify({
                "master": parseInt($("#selMaster").val()),
            }),
            success: function (reply) {
                if (reply["error"]) {
                    var template = $("#tmplSSHAlertError").html();
                    var alert = $(template);

                    alert.find("span[data-copy='error']").text(reply["error"]);
                    $("#divAlerts").html(alert);

                    return;
                }

                var template = $("#tmplSSHAlert").html();
                var alert = $(template);

                // show the password masked but keep the real value in the
                // data-copy attribute for the copy buttons
                alert.find("span[data-copy='command'] > code").text(reply["command"]);
                alert.find("span[data-copy='command']").attr("data-copy", reply["command"]);
                alert.find("span[data-copy='password'] > code").text("********");
                alert.find("span[data-copy='password']").attr("data-copy", reply["password"]);
                $("#divAlerts").html(alert);

                $('.copy-button').click(function () {
                    // fixed: was class='style: hidden;', which put CSS into
                    // the class attribute and did nothing.  Position the
                    // scratch textarea off-screen instead of display:none,
                    // because a non-rendered element cannot be selected for
                    // document.execCommand('copy').
                    var textarea = $("<textarea id='textarea' style='position: absolute; left: -9999px;'></textarea>");
                    textarea.text($(this).next().attr("data-copy"));
                    textarea.appendTo("body");

                    textarea = document.getElementById("textarea")
                    textarea.select();
                    textarea.setSelectionRange(0, textarea.value.length + 1);
                    document.execCommand('copy');
                    document.body.removeChild(textarea)
                });
            },
            dataType: "json",
        });
    });
});
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -0,0 +1,7 @@
|
|||
package portal

// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.

// The portal's static web assets (HTML/JS/CSS under assets/) are compiled
// into bindata.go by `go generate`; the gofmt pass keeps the generated file
// formatting-clean.

//go:generate go run ../../vendor/github.com/go-bindata/go-bindata/go-bindata -nometadata -pkg $GOPACKAGE -prefix assets assets/...
//go:generate gofmt -s -l -w bindata.go
|
|
@ -0,0 +1,183 @@
|
|||
package kubeconfig
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
v1 "k8s.io/client-go/tools/clientcmd/api/v1"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
"github.com/Azure/ARO-RP/pkg/api/validate"
|
||||
"github.com/Azure/ARO-RP/pkg/database"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/middleware"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/util/clientcache"
|
||||
"github.com/Azure/ARO-RP/pkg/proxy"
|
||||
"github.com/Azure/ARO-RP/pkg/util/roundtripper"
|
||||
)
|
||||
|
||||
const (
	// kubeconfigNewTimeout is how long credentials issued by new remain
	// valid; it is enforced via the PortalDocument TTL.
	kubeconfigNewTimeout = 6 * time.Hour
)

// kubeconfig implements the portal's kubeconfig feature: it mints
// short-lived kubeconfigs for a cluster and reverse-proxies
// bearer-authenticated API server traffic to that cluster.
type kubeconfig struct {
	log           *logrus.Entry // operational logging
	baseAccessLog *logrus.Entry // HTTP access logging for the bearer-authenticated route

	servingCert      *x509.Certificate // portal serving certificate, embedded as the CA in issued kubeconfigs
	elevatedGroupIDs []string          // AAD group IDs granted elevated access

	dbOpenShiftClusters database.OpenShiftClusters
	dbPortal            database.Portal

	dialer      proxy.Dialer            // dials through to the cluster's private endpoint
	clientCache clientcache.ClientCache // caches per-cluster HTTP clients

	newToken func() string // bearer token generator; replaced in tests for determinism
}
|
||||
|
||||
// New wires up the kubeconfig feature on the given routers: an
// AAD-authenticated POST route which mints new kubeconfigs, and a
// bearer-token-authenticated route which reverse-proxies kubeconfig traffic
// to the cluster API server.
func New(baseLog *logrus.Entry,
	baseAccessLog *logrus.Entry,
	servingCert *x509.Certificate,
	elevatedGroupIDs []string,
	dbOpenShiftClusters database.OpenShiftClusters,
	dbPortal database.Portal,
	dialer proxy.Dialer,
	aadAuthenticatedRouter,
	unauthenticatedRouter *mux.Router) *kubeconfig {
	k := &kubeconfig{
		log:           baseLog,
		baseAccessLog: baseAccessLog,

		servingCert:      servingCert,
		elevatedGroupIDs: elevatedGroupIDs,

		dbOpenShiftClusters: dbOpenShiftClusters,
		dbPortal:            dbPortal,

		dialer:      dialer,
		clientCache: clientcache.New(time.Hour),

		newToken: func() string { return uuid.NewV4().String() },
	}

	// director resolves the per-cluster client and rewrites the request;
	// roundTripper performs the onward call (or replays a director error).
	rp := &httputil.ReverseProxy{
		Director:  k.director,
		Transport: roundtripper.RoundTripperFunc(k.roundTripper),
		ErrorLog:  log.New(k.log.Writer(), "", 0),
	}

	aadAuthenticatedRouter.NewRoute().Methods(http.MethodPost).Path("/subscriptions/{subscriptionId}/resourcegroups/{resourceGroupName}/providers/microsoft.redhatopenshift/openshiftclusters/{resourceName}/kubeconfig/new").HandlerFunc(k.new)

	// the proxy route authenticates via the bearer token embedded in the
	// issued kubeconfig (middleware.Bearer), not via AAD, so it hangs off
	// the unauthenticated router.  Bearer is installed before Log.
	bearerAuthenticatedRouter := unauthenticatedRouter.NewRoute().Subrouter()
	bearerAuthenticatedRouter.Use(middleware.Bearer(k.dbPortal))
	bearerAuthenticatedRouter.Use(middleware.Log(k.baseAccessLog))

	bearerAuthenticatedRouter.PathPrefix("/subscriptions/{subscriptionId}/resourcegroups/{resourceGroupName}/providers/microsoft.redhatopenshift/openshiftclusters/{resourceName}/kubeconfig/proxy/").Handler(rp)

	return k
}
|
||||
|
||||
// new creates a new PortalDocument allowing kubeconfig access to a cluster for
// 6 hours and returns a kubeconfig with the temporary credentials.
func (k *kubeconfig) new(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	// the route is .../openshiftclusters/{resourceName}/kubeconfig/new; the
	// first 9 path segments form the cluster resource ID.
	resourceID := strings.Join(strings.Split(r.URL.Path, "/")[:9], "/")
	if !validate.RxClusterID.MatchString(resourceID) {
		http.Error(w, fmt.Sprintf("invalid resourceId %q", resourceID), http.StatusBadRequest)
		return
	}

	// elevated access is granted only when the caller is in one of the
	// configured elevated AAD groups.
	elevated := middleware.GroupsIntersect(k.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))

	token := k.newToken()

	// the token doubles as the document ID and the bearer token returned in
	// the kubeconfig; the document TTL bounds the credential lifetime.
	portalDoc := &api.PortalDocument{
		ID:  token,
		TTL: int(kubeconfigNewTimeout / time.Second),
		Portal: &api.Portal{
			Username: ctx.Value(middleware.ContextKeyUsername).(string),
			ID:       resourceID,
			Kubeconfig: &api.Kubeconfig{
				Elevated: elevated,
			},
		},
	}

	_, err := k.dbPortal.Create(ctx, portalDoc)
	if err != nil {
		k.internalServerError(w, err)
		return
	}

	b, err := k.makeKubeconfig("https://"+r.Host+resourceID+"/kubeconfig/proxy", token)
	if err != nil {
		k.internalServerError(w, err)
		return
	}

	// the download filename is the cluster name (path segment 8), suffixed
	// when the credentials are elevated.
	filename := strings.Split(r.URL.Path, "/")[8]
	if elevated {
		filename += "-elevated"
	}

	w.Header().Add("Content-Type", "application/json")
	w.Header().Add("Content-Disposition", `attachment; filename="`+filename+`.kubeconfig"`)
	_, _ = w.Write(b)
}
|
||||
|
||||
// internalServerError logs err and replies with a generic 500, so that no
// internal detail leaks to the client.
func (k *kubeconfig) internalServerError(w http.ResponseWriter, err error) {
	k.log.Warn(err)
	http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
|
||||
|
||||
func (k *kubeconfig) makeKubeconfig(server, token string) ([]byte, error) {
|
||||
return json.MarshalIndent(&v1.Config{
|
||||
APIVersion: "v1",
|
||||
Kind: "Config",
|
||||
Clusters: []v1.NamedCluster{
|
||||
{
|
||||
Name: "cluster",
|
||||
Cluster: v1.Cluster{
|
||||
Server: server,
|
||||
CertificateAuthorityData: pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: k.servingCert.Raw,
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
AuthInfos: []v1.NamedAuthInfo{
|
||||
{
|
||||
Name: "user",
|
||||
AuthInfo: v1.AuthInfo{
|
||||
Token: token,
|
||||
},
|
||||
},
|
||||
},
|
||||
Contexts: []v1.NamedContext{
|
||||
{
|
||||
Name: "context",
|
||||
Context: v1.Context{
|
||||
Cluster: "cluster",
|
||||
Namespace: "default",
|
||||
AuthInfo: "user",
|
||||
},
|
||||
},
|
||||
},
|
||||
CurrentContext: "context",
|
||||
}, "", " ")
|
||||
}
|
|
@ -0,0 +1,189 @@
|
|||
package kubeconfig
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
"github.com/Azure/ARO-RP/pkg/database/cosmosdb"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/middleware"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/util/responsewriter"
|
||||
testdatabase "github.com/Azure/ARO-RP/test/database"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
resourceID := "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster"
|
||||
elevatedGroupIDs := []string{"10000000-0000-0000-0000-000000000000"}
|
||||
username := "username"
|
||||
password := "password"
|
||||
|
||||
servingCert := &x509.Certificate{}
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
r func(*http.Request)
|
||||
elevated bool
|
||||
fixtureChecker func(*testdatabase.Fixture, *testdatabase.Checker, *cosmosdb.FakePortalDocumentClient)
|
||||
wantStatusCode int
|
||||
wantHeaders http.Header
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
name: "success - not elevated",
|
||||
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, portalClient *cosmosdb.FakePortalDocumentClient) {
|
||||
portalDocument := &api.PortalDocument{
|
||||
ID: password,
|
||||
TTL: 21600,
|
||||
Portal: &api.Portal{
|
||||
Username: username,
|
||||
ID: resourceID,
|
||||
Kubeconfig: &api.Kubeconfig{},
|
||||
},
|
||||
}
|
||||
checker.AddPortalDocuments(portalDocument)
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantHeaders: http.Header{
|
||||
"Content-Disposition": []string{`attachment; filename="cluster.kubeconfig"`},
|
||||
},
|
||||
wantBody: "{\n \"kind\": \"Config\",\n \"apiVersion\": \"v1\",\n \"preferences\": {},\n \"clusters\": [\n {\n \"name\": \"cluster\",\n \"cluster\": {\n \"server\": \"https://localhost:8444/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster/kubeconfig/proxy\",\n \"certificate-authority-data\": \"LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCi0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K\"\n }\n }\n ],\n \"users\": [\n {\n \"name\": \"user\",\n \"user\": {\n \"token\": \"password\"\n }\n }\n ],\n \"contexts\": [\n {\n \"name\": \"context\",\n \"context\": {\n \"cluster\": \"cluster\",\n \"user\": \"user\",\n \"namespace\": \"default\"\n }\n }\n ],\n \"current-context\": \"context\"\n}",
|
||||
},
|
||||
{
|
||||
name: "success - elevated",
|
||||
elevated: true,
|
||||
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, portalClient *cosmosdb.FakePortalDocumentClient) {
|
||||
portalDocument := &api.PortalDocument{
|
||||
ID: password,
|
||||
TTL: 21600,
|
||||
Portal: &api.Portal{
|
||||
Username: username,
|
||||
ID: resourceID,
|
||||
Kubeconfig: &api.Kubeconfig{
|
||||
Elevated: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
checker.AddPortalDocuments(portalDocument)
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantHeaders: http.Header{
|
||||
"Content-Disposition": []string{`attachment; filename="cluster-elevated.kubeconfig"`},
|
||||
},
|
||||
wantBody: "{\n \"kind\": \"Config\",\n \"apiVersion\": \"v1\",\n \"preferences\": {},\n \"clusters\": [\n {\n \"name\": \"cluster\",\n \"cluster\": {\n \"server\": \"https://localhost:8444/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster/kubeconfig/proxy\",\n \"certificate-authority-data\": \"LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCi0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K\"\n }\n }\n ],\n \"users\": [\n {\n \"name\": \"user\",\n \"user\": {\n \"token\": \"password\"\n }\n }\n ],\n \"contexts\": [\n {\n \"name\": \"context\",\n \"context\": {\n \"cluster\": \"cluster\",\n \"user\": \"user\",\n \"namespace\": \"default\"\n }\n }\n ],\n \"current-context\": \"context\"\n}",
|
||||
},
|
||||
{
|
||||
name: "bad path",
|
||||
r: func(r *http.Request) {
|
||||
r.URL.Path = "/subscriptions/BAD/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster/kubeconfig/new"
|
||||
},
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantBody: "invalid resourceId \"/subscriptions/BAD/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster\"\n",
|
||||
},
|
||||
{
|
||||
name: "sad database",
|
||||
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, portalClient *cosmosdb.FakePortalDocumentClient) {
|
||||
portalClient.SetError(fmt.Errorf("sad"))
|
||||
},
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
wantBody: "Internal Server Error\n",
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
dbPortal, portalClient := testdatabase.NewFakePortal()
|
||||
|
||||
fixture := testdatabase.NewFixture().
|
||||
WithPortal(dbPortal)
|
||||
|
||||
checker := testdatabase.NewChecker()
|
||||
|
||||
if tt.fixtureChecker != nil {
|
||||
tt.fixtureChecker(fixture, checker, portalClient)
|
||||
}
|
||||
|
||||
err := fixture.Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, middleware.ContextKeyUsername, username)
|
||||
if tt.elevated {
|
||||
ctx = context.WithValue(ctx, middleware.ContextKeyGroups, elevatedGroupIDs)
|
||||
} else {
|
||||
ctx = context.WithValue(ctx, middleware.ContextKeyGroups, []string(nil))
|
||||
}
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodPost,
|
||||
"https://localhost:8444"+resourceID+"/kubeconfig/new", nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
aadAuthenticatedRouter := &mux.Router{}
|
||||
|
||||
k := New(logrus.NewEntry(logrus.StandardLogger()), nil, servingCert, elevatedGroupIDs, nil, dbPortal, nil, aadAuthenticatedRouter, &mux.Router{})
|
||||
|
||||
k.newToken = func() string { return password }
|
||||
|
||||
if tt.r != nil {
|
||||
tt.r(r)
|
||||
}
|
||||
|
||||
w := responsewriter.New(r)
|
||||
|
||||
aadAuthenticatedRouter.ServeHTTP(w, r)
|
||||
|
||||
portalClient.SetError(nil)
|
||||
|
||||
for _, err = range checker.CheckPortals(portalClient) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
resp := w.Response()
|
||||
|
||||
if resp.StatusCode != tt.wantStatusCode {
|
||||
t.Error(resp.StatusCode)
|
||||
}
|
||||
|
||||
for k, v := range tt.wantHeaders {
|
||||
if !reflect.DeepEqual(resp.Header[k], v) {
|
||||
t.Errorf(k, resp.Header[k], v)
|
||||
}
|
||||
}
|
||||
|
||||
wantContentType := "application/json"
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
wantContentType = "text/plain; charset=utf-8"
|
||||
}
|
||||
if resp.Header.Get("Content-Type") != wantContentType {
|
||||
t.Error(resp.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(b) != tt.wantBody {
|
||||
t.Errorf("%q", string(b))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,218 @@
|
|||
package kubeconfig
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/ghodss/yaml"
	v1 "k8s.io/client-go/tools/clientcmd/api/v1"

	"github.com/Azure/ARO-RP/pkg/api"
	"github.com/Azure/ARO-RP/pkg/api/validate"
	"github.com/Azure/ARO-RP/pkg/portal/middleware"
	"github.com/Azure/ARO-RP/pkg/portal/util/responsewriter"
	"github.com/Azure/ARO-RP/pkg/util/pem"
	"github.com/Azure/ARO-RP/pkg/util/restconfig"
)
|
||||
|
||||
const (
|
||||
kubeconfigTimeout = time.Hour
|
||||
)
|
||||
|
||||
// contextKey is a private type for the request-context keys written by
// director and read by roundTripper.
type contextKey int

const (
	contextKeyClient   contextKey = iota // *http.Client to use for the onward request
	contextKeyResponse                   // pre-built *http.Response recorded by error when director fails
)
|
||||
|
||||
// director is called by the ReverseProxy. It converts an incoming request into
// the one that'll go out to the API server. It also resolves an HTTP client
// that will be able to make the ongoing request.
//
// Unfortunately the signature of httputil.ReverseProxy.Director does not allow
// us to return values. We get around this limitation slightly naughtily by
// storing return information in the request context.
func (k *kubeconfig) director(r *http.Request) {
	ctx := r.Context()

	// the Bearer middleware stashes the authenticated PortalDocument; it
	// must exist and carry the Kubeconfig stanza for this route
	portalDoc, _ := ctx.Value(middleware.ContextKeyPortalDoc).(*api.PortalDocument)
	if portalDoc == nil || portalDoc.Portal.Kubeconfig == nil {
		k.error(r, http.StatusForbidden, nil)
		return
	}

	// the resource ID in the URL must match the one the credential was
	// issued for (case-insensitively, as ARM resource IDs are)
	resourceID := strings.Join(strings.Split(r.URL.Path, "/")[:9], "/")
	if !validate.RxClusterID.MatchString(resourceID) ||
		!strings.EqualFold(resourceID, portalDoc.Portal.ID) {
		k.error(r, http.StatusBadRequest, nil)
		return
	}

	// cache key: one client per (cluster, elevation) pair
	key := struct {
		resourceID string
		elevated   bool
	}{
		resourceID: portalDoc.Portal.ID,
		elevated:   portalDoc.Portal.Kubeconfig.Elevated,
	}

	cli := k.clientCache.Get(key)
	if cli == nil {
		var err error
		cli, err = k.cli(ctx, key.resourceID, key.elevated)
		if err != nil {
			k.error(r, http.StatusInternalServerError, err)
			return
		}

		k.clientCache.Put(key, cli)
	}

	// rewrite the request to target the cluster API server; everything
	// after the 11th path segment (.../kubeconfig/proxy/) is forwarded
	r.RequestURI = ""
	r.URL.Scheme = "https"
	r.URL.Host = "kubernetes:6443"
	r.URL.Path = "/" + strings.Join(strings.Split(r.URL.Path, "/")[11:], "/")
	r.Header.Del("Authorization")
	r.Host = r.URL.Host

	// http.Request.WithContext returns a copy of the original Request with the
	// new context, but we have no way to return it, so we overwrite our
	// existing request.
	*r = *r.WithContext(context.WithValue(ctx, contextKeyClient, cli))
}
|
||||
|
||||
// cli returns an appropriately configured HTTP client for forwarding the
|
||||
// incoming request to a cluster
|
||||
func (k *kubeconfig) cli(ctx context.Context, resourceID string, elevated bool) (*http.Client, error) {
|
||||
openShiftDoc, err := k.dbOpenShiftClusters.Get(ctx, resourceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kc := openShiftDoc.OpenShiftCluster.Properties.AROSREKubeconfig
|
||||
if elevated {
|
||||
kc = openShiftDoc.OpenShiftCluster.Properties.AROServiceKubeconfig
|
||||
}
|
||||
|
||||
var kubeconfig *v1.Config
|
||||
err = yaml.Unmarshal(kc, &kubeconfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var b []byte
|
||||
b = append(b, kubeconfig.AuthInfos[0].AuthInfo.ClientKeyData...)
|
||||
b = append(b, kubeconfig.AuthInfos[0].AuthInfo.ClientCertificateData...)
|
||||
|
||||
clientKey, clientCerts, err := pem.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, caCerts, err := pem.Parse(kubeconfig.Clusters[0].Cluster.CertificateAuthorityData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
for _, caCert := range caCerts {
|
||||
pool.AddCert(caCert)
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: restconfig.DialContext(k.dialer, openShiftDoc.OpenShiftCluster),
|
||||
TLSClientConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{
|
||||
{
|
||||
Certificate: [][]byte{
|
||||
clientCerts[0].Raw,
|
||||
},
|
||||
PrivateKey: clientKey,
|
||||
},
|
||||
},
|
||||
RootCAs: pool,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// roundTripper is called by ReverseProxy to make the onward request happen. We
|
||||
// check if we had an error earlier and return that if we did. Otherwise we dig
|
||||
// out the client and call it.
|
||||
func (k *kubeconfig) roundTripper(r *http.Request) (*http.Response, error) {
|
||||
if resp, ok := r.Context().Value(contextKeyResponse).(*http.Response); ok {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
cli := r.Context().Value(contextKeyClient).(*http.Client)
|
||||
resp, err := cli.Do(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusSwitchingProtocols {
|
||||
resp.Body = newCancelBody(resp.Body.(io.ReadWriteCloser), kubeconfigTimeout)
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// error logs err (when non-nil) and stashes a plain-text HTTP error response
// in the request context under contextKeyResponse; roundTripper later finds
// and returns that response instead of forwarding the request to the cluster.
func (k *kubeconfig) error(r *http.Request, statusCode int, err error) {
	if err != nil {
		k.log.Warn(err)
	}

	w := responsewriter.New(r)

	http.Error(w, http.StatusText(statusCode), statusCode)

	// http.Request.WithContext returns a copy; overwrite the caller's request
	// in place so the stored response is visible downstream.
	*r = *r.WithContext(context.WithValue(r.Context(), contextKeyResponse, w.Response()))
}
|
||||
|
||||
// cancelBody is a workaround for the fact that http timeouts are incompatible
|
||||
// with hijacked connections (https://github.com/golang/go/issues/31391):
|
||||
// net/http.cancelTimerBody does not implement Writer.
|
||||
type cancelBody struct {
|
||||
io.ReadWriteCloser
|
||||
t *time.Timer
|
||||
c chan struct{}
|
||||
}
|
||||
|
||||
func (b *cancelBody) wait() {
|
||||
select {
|
||||
case <-b.t.C:
|
||||
b.ReadWriteCloser.Close()
|
||||
case <-b.c:
|
||||
b.t.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *cancelBody) Close() error {
|
||||
select {
|
||||
case b.c <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
return b.ReadWriteCloser.Close()
|
||||
}
|
||||
|
||||
func newCancelBody(rwc io.ReadWriteCloser, d time.Duration) io.ReadWriteCloser {
|
||||
b := &cancelBody{
|
||||
ReadWriteCloser: rwc,
|
||||
t: time.NewTimer(d),
|
||||
c: make(chan struct{}),
|
||||
}
|
||||
|
||||
go b.wait()
|
||||
|
||||
return b
|
||||
}
|
|
@ -0,0 +1,403 @@
|
|||
package kubeconfig
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sirupsen/logrus"
|
||||
v1 "k8s.io/client-go/tools/clientcmd/api/v1"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
"github.com/Azure/ARO-RP/pkg/database/cosmosdb"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/util/responsewriter"
|
||||
mock_proxy "github.com/Azure/ARO-RP/pkg/util/mocks/proxy"
|
||||
utiltls "github.com/Azure/ARO-RP/pkg/util/tls"
|
||||
testdatabase "github.com/Azure/ARO-RP/test/database"
|
||||
"github.com/Azure/ARO-RP/test/util/listener"
|
||||
)
|
||||
|
||||
// fakeServer returns a test listener for an HTTPS server which validates its
// client and echos back the request it received.  Client certificates are
// verified against the first CA certificate; the handler records the
// authenticated client's CommonName in an X-Authenticated-Name header before
// dumping the whole request into the response body.
func fakeServer(cacerts []*x509.Certificate, serverkey *rsa.PrivateKey, servercerts []*x509.Certificate) *listener.Listener {
	l := listener.NewListener()

	pool := x509.NewCertPool()
	pool.AddCert(cacerts[0])

	// Serve in the background for the lifetime of the listener; errors are
	// deliberately ignored (the listener is closed by the test).
	go func() {
		_ = http.Serve(tls.NewListener(l, &tls.Config{
			Certificates: []tls.Certificate{
				{
					Certificate: [][]byte{
						servercerts[0].Raw,
					},
					PrivateKey: serverkey,
				},
			},
			ClientAuth: tls.RequireAndVerifyClientCert,
			ClientCAs:  pool,
		}), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			r.Header.Add("X-Authenticated-Name", r.TLS.PeerCertificates[0].Subject.CommonName)
			b, _ := httputil.DumpRequest(r, true)
			_, _ = w.Write(b)
		}))
	}()

	return l
}
|
||||
|
||||
// testKubeconfig builds a minimal single-cluster/single-user kubeconfig whose
// client credentials and cluster CA are the PEM encodings of the supplied key
// and certificates, serialized as JSON (JSON is a subset of YAML, so the
// result is valid input for the code under test).
func testKubeconfig(cacerts []*x509.Certificate, clientkey *rsa.PrivateKey, clientcerts []*x509.Certificate) ([]byte, error) {
	kc := &v1.Config{
		Clusters: []v1.NamedCluster{
			{},
		},
		AuthInfos: []v1.NamedAuthInfo{
			{},
		},
	}

	var err error
	kc.AuthInfos[0].AuthInfo.ClientKeyData, err = utiltls.PrivateKeyAsBytes(clientkey)
	if err != nil {
		return nil, err
	}

	kc.AuthInfos[0].AuthInfo.ClientCertificateData, err = utiltls.CertAsBytes(clientcerts[0])
	if err != nil {
		return nil, err
	}

	kc.Clusters[0].Cluster.CertificateAuthorityData, err = utiltls.CertAsBytes(cacerts[0])
	if err != nil {
		return nil, err
	}

	return json.Marshal(kc)
}
|
||||
|
||||
func TestProxy(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
resourceID := "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster"
|
||||
username := "username"
|
||||
token := "00000000-0000-0000-0000-000000000000"
|
||||
privateEndpointIP := "1.2.3.4"
|
||||
|
||||
cakey, cacerts, err := utiltls.GenerateKeyAndCertificate("ca", nil, nil, true, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serverkey, servercerts, err := utiltls.GenerateKeyAndCertificate("kubernetes", cakey, cacerts[0], false, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sreClientkey, sreClientcerts, err := utiltls.GenerateKeyAndCertificate("system:aro-sre", cakey, cacerts[0], false, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sreKubeconfig, err := testKubeconfig(cacerts, sreClientkey, sreClientcerts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serviceClientkey, serviceClientcerts, err := utiltls.GenerateKeyAndCertificate("system:aro-service", cakey, cacerts[0], false, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serviceKubeconfig, err := testKubeconfig(cacerts, serviceClientkey, serviceClientcerts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
l := fakeServer(cacerts, serverkey, servercerts)
|
||||
defer l.Close()
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
r func(*http.Request)
|
||||
fixtureChecker func(*testdatabase.Fixture, *testdatabase.Checker, *cosmosdb.FakeOpenShiftClusterDocumentClient, *cosmosdb.FakePortalDocumentClient)
|
||||
mocks func(*mock_proxy.MockDialer)
|
||||
wantStatusCode int
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
name: "success - elevated",
|
||||
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
|
||||
portalDocument := &api.PortalDocument{
|
||||
ID: token,
|
||||
TTL: 21600,
|
||||
Portal: &api.Portal{
|
||||
Username: username,
|
||||
ID: resourceID,
|
||||
Kubeconfig: &api.Kubeconfig{
|
||||
Elevated: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
fixture.AddPortalDocuments(portalDocument)
|
||||
checker.AddPortalDocuments(portalDocument)
|
||||
openShiftClusterDocument := &api.OpenShiftClusterDocument{
|
||||
ID: resourceID,
|
||||
Key: resourceID,
|
||||
OpenShiftCluster: &api.OpenShiftCluster{
|
||||
Properties: api.OpenShiftClusterProperties{
|
||||
NetworkProfile: api.NetworkProfile{
|
||||
PrivateEndpointIP: privateEndpointIP,
|
||||
},
|
||||
AROServiceKubeconfig: api.SecureBytes(serviceKubeconfig),
|
||||
AROSREKubeconfig: api.SecureBytes(sreKubeconfig),
|
||||
},
|
||||
},
|
||||
}
|
||||
fixture.AddOpenShiftClusterDocuments(openShiftClusterDocument)
|
||||
checker.AddOpenShiftClusterDocuments(openShiftClusterDocument)
|
||||
},
|
||||
mocks: func(dialer *mock_proxy.MockDialer) {
|
||||
dialer.EXPECT().DialContext(gomock.Any(), "tcp", privateEndpointIP+":6443").Return(l.DialContext(ctx, "", ""))
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantBody: "GET /test HTTP/1.1\r\nHost: kubernetes:6443\r\nAccept-Encoding: gzip\r\nUser-Agent: Go-http-client/1.1\r\nX-Authenticated-Name: system:aro-service\r\n\r\n",
|
||||
},
|
||||
{
|
||||
name: "success - not elevated",
|
||||
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
|
||||
portalDocument := &api.PortalDocument{
|
||||
ID: token,
|
||||
TTL: 21600,
|
||||
Portal: &api.Portal{
|
||||
Username: username,
|
||||
ID: resourceID,
|
||||
Kubeconfig: &api.Kubeconfig{},
|
||||
},
|
||||
}
|
||||
fixture.AddPortalDocuments(portalDocument)
|
||||
checker.AddPortalDocuments(portalDocument)
|
||||
openShiftClusterDocument := &api.OpenShiftClusterDocument{
|
||||
ID: resourceID,
|
||||
Key: resourceID,
|
||||
OpenShiftCluster: &api.OpenShiftCluster{
|
||||
Properties: api.OpenShiftClusterProperties{
|
||||
NetworkProfile: api.NetworkProfile{
|
||||
PrivateEndpointIP: privateEndpointIP,
|
||||
},
|
||||
AROServiceKubeconfig: api.SecureBytes(serviceKubeconfig),
|
||||
AROSREKubeconfig: api.SecureBytes(sreKubeconfig),
|
||||
},
|
||||
},
|
||||
}
|
||||
fixture.AddOpenShiftClusterDocuments(openShiftClusterDocument)
|
||||
checker.AddOpenShiftClusterDocuments(openShiftClusterDocument)
|
||||
},
|
||||
mocks: func(dialer *mock_proxy.MockDialer) {
|
||||
dialer.EXPECT().DialContext(gomock.Any(), "tcp", privateEndpointIP+":6443").Return(l.DialContext(ctx, "", ""))
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantBody: "GET /test HTTP/1.1\r\nHost: kubernetes:6443\r\nAccept-Encoding: gzip\r\nUser-Agent: Go-http-client/1.1\r\nX-Authenticated-Name: system:aro-sre\r\n\r\n",
|
||||
},
|
||||
{
|
||||
name: "no auth",
|
||||
r: func(r *http.Request) {
|
||||
r.Header.Del("Authorization")
|
||||
},
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
wantBody: "Forbidden\n",
|
||||
},
|
||||
{
|
||||
name: "bad auth, not uuid",
|
||||
r: func(r *http.Request) {
|
||||
r.Header.Set("Authorization", "Bearer bad")
|
||||
},
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
wantBody: "Forbidden\n",
|
||||
},
|
||||
{
|
||||
name: "bad auth",
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
wantBody: "Forbidden\n",
|
||||
},
|
||||
{
|
||||
name: "not kubeconfig record",
|
||||
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
|
||||
portalDocument := &api.PortalDocument{
|
||||
ID: token,
|
||||
TTL: 21600,
|
||||
Portal: &api.Portal{
|
||||
Username: username,
|
||||
ID: resourceID,
|
||||
},
|
||||
}
|
||||
fixture.AddPortalDocuments(portalDocument)
|
||||
checker.AddPortalDocuments(portalDocument)
|
||||
},
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
wantBody: "Forbidden\n",
|
||||
},
|
||||
{
|
||||
name: "bad path",
|
||||
r: func(r *http.Request) {
|
||||
r.URL.Path = "/subscriptions/BAD/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster/kubeconfig/proxy/test"
|
||||
},
|
||||
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
|
||||
portalDocument := &api.PortalDocument{
|
||||
ID: token,
|
||||
TTL: 21600,
|
||||
Portal: &api.Portal{
|
||||
Username: username,
|
||||
ID: resourceID,
|
||||
Kubeconfig: &api.Kubeconfig{},
|
||||
},
|
||||
}
|
||||
fixture.AddPortalDocuments(portalDocument)
|
||||
checker.AddPortalDocuments(portalDocument)
|
||||
},
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantBody: "Bad Request\n",
|
||||
},
|
||||
{
|
||||
name: "mismatched path",
|
||||
r: func(r *http.Request) {
|
||||
r.URL.Path = "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/mismatch/kubeconfig/proxy/test"
|
||||
},
|
||||
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
|
||||
portalDocument := &api.PortalDocument{
|
||||
ID: token,
|
||||
TTL: 21600,
|
||||
Portal: &api.Portal{
|
||||
Username: username,
|
||||
ID: resourceID,
|
||||
Kubeconfig: &api.Kubeconfig{},
|
||||
},
|
||||
}
|
||||
fixture.AddPortalDocuments(portalDocument)
|
||||
checker.AddPortalDocuments(portalDocument)
|
||||
},
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantBody: "Bad Request\n",
|
||||
},
|
||||
{
|
||||
name: "sad portal database",
|
||||
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
|
||||
portalClient.SetError(fmt.Errorf("sad"))
|
||||
},
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
wantBody: "Forbidden\n",
|
||||
},
|
||||
{
|
||||
name: "sad openshift database",
|
||||
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
|
||||
portalDocument := &api.PortalDocument{
|
||||
ID: token,
|
||||
TTL: 21600,
|
||||
Portal: &api.Portal{
|
||||
Username: username,
|
||||
ID: resourceID,
|
||||
Kubeconfig: &api.Kubeconfig{},
|
||||
},
|
||||
}
|
||||
fixture.AddPortalDocuments(portalDocument)
|
||||
checker.AddPortalDocuments(portalDocument)
|
||||
|
||||
openShiftClustersClient.SetError(fmt.Errorf("sad"))
|
||||
|
||||
},
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
wantBody: "Internal Server Error\n",
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dbPortal, portalClient := testdatabase.NewFakePortal()
|
||||
dbOpenShiftClusters, openShiftClustersClient := testdatabase.NewFakeOpenShiftClusters()
|
||||
|
||||
fixture := testdatabase.NewFixture().
|
||||
WithOpenShiftClusters(dbOpenShiftClusters).
|
||||
WithPortal(dbPortal)
|
||||
|
||||
checker := testdatabase.NewChecker()
|
||||
|
||||
if tt.fixtureChecker != nil {
|
||||
tt.fixtureChecker(fixture, checker, openShiftClustersClient, portalClient)
|
||||
}
|
||||
|
||||
err := fixture.Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet,
|
||||
"https://localhost:8444"+resourceID+"/kubeconfig/proxy/test", nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
dialer := mock_proxy.NewMockDialer(ctrl)
|
||||
if tt.mocks != nil {
|
||||
tt.mocks(dialer)
|
||||
}
|
||||
|
||||
unauthenticatedRouter := &mux.Router{}
|
||||
|
||||
k := New(logrus.NewEntry(logrus.StandardLogger()), logrus.NewEntry(logrus.StandardLogger()), nil, nil, dbOpenShiftClusters, dbPortal, dialer, &mux.Router{}, unauthenticatedRouter)
|
||||
|
||||
k.newToken = func() string { return token }
|
||||
|
||||
if tt.r != nil {
|
||||
tt.r(r)
|
||||
}
|
||||
|
||||
w := responsewriter.New(r)
|
||||
|
||||
unauthenticatedRouter.ServeHTTP(w, r)
|
||||
|
||||
openShiftClustersClient.SetError(nil)
|
||||
portalClient.SetError(nil)
|
||||
|
||||
for _, err = range checker.CheckOpenShiftClusters(openShiftClustersClient) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
for _, err = range checker.CheckPortals(portalClient) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
resp := w.Response()
|
||||
|
||||
if resp.StatusCode != tt.wantStatusCode {
|
||||
t.Error(resp.StatusCode)
|
||||
}
|
||||
|
||||
if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" {
|
||||
t.Error(resp.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(b) != tt.wantBody {
|
||||
t.Errorf("%q", string(b))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,394 @@
|
|||
package middleware
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/go-autorest/autorest/adal"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/sessions"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/microsoft"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/util/deployment"
|
||||
"github.com/Azure/ARO-RP/pkg/util/roundtripper"
|
||||
)
|
||||
|
||||
const (
	// SessionName is the cookie name used for the portal login session.
	SessionName = "session"
	// SessionKeyExpires holds the session's expiry instant (a time.Time).
	SessionKeyExpires = "expires"
	// sessionKeyRedirectPath stores the path to return to after AAD login.
	sessionKeyRedirectPath = "redirect_path"
	// sessionKeyState holds the OAuth2 state nonce used for CSRF protection.
	sessionKeyState = "state"
	// SessionKeyUsername holds the authenticated user's preferred username.
	SessionKeyUsername = "user_name"
	// SessionKeyGroups holds the authenticated user's AAD group IDs.
	SessionKeyGroups = "groups"
)
|
||||
|
||||
// init registers time.Time with gob so that session expiry values can be
// (de)serialized by the securecookie-backed session store.
func init() {
	gob.Register(time.Time{})
}
|
||||
|
||||
// AAD is responsible for ensuring that we have a valid login session with AAD.
type AAD interface {
	AAD(http.Handler) http.Handler
	Logout(string) http.Handler
	Redirect(http.Handler) http.Handler
}

// oauther abstracts the parts of *oauth2.Config this package uses, so tests
// can substitute a fake.
type oauther interface {
	AuthCodeURL(string, ...oauth2.AuthCodeOption) string
	Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error)
}

// Verifier validates a raw OIDC id_token and returns its parsed form.
type Verifier interface {
	Verify(context.Context, string) (oidctoken, error)
}

// idTokenVerifier adapts *oidc.IDTokenVerifier to the Verifier interface.
type idTokenVerifier struct {
	*oidc.IDTokenVerifier
}

// Verify delegates to the embedded oidc verifier, returning the token behind
// the narrower oidctoken interface.
func (v *idTokenVerifier) Verify(ctx context.Context, rawIDToken string) (oidctoken, error) {
	return v.IDTokenVerifier.Verify(ctx, rawIDToken)
}

// oidctoken is the subset of the verified id_token we rely on.
type oidctoken interface {
	Claims(interface{}) error
}
|
||||
|
||||
// NewVerifier discovers the AAD v2.0 OIDC endpoint for tenantID and returns a
// Verifier which checks that id_tokens were issued to clientID.
func NewVerifier(ctx context.Context, tenantID, clientID string) (Verifier, error) {
	provider, err := oidc.NewProvider(ctx, "https://login.microsoftonline.com/"+tenantID+"/v2.0")
	if err != nil {
		return nil, err
	}

	return &idTokenVerifier{
		provider.Verifier(&oidc.Config{
			ClientID: clientID,
		}),
	}, nil
}
|
||||
|
||||
// claims is the subset of id_token claims the portal consumes.
type claims struct {
	Groups            []string `json:"groups,omitempty"`
	PreferredUsername string   `json:"preferred_username,omitempty"`
}

// aad implements AAD.  now and rt are fields (rather than direct calls to
// time.Now / http.DefaultTransport) so tests can substitute a fixed clock and
// a fake transport.
type aad struct {
	deploymentMode deployment.Mode
	log            *logrus.Entry
	now            func() time.Time
	rt             http.RoundTripper

	// AAD application identity used for the certificate client assertion.
	tenantID    string
	clientID    string
	clientKey   *rsa.PrivateKey
	clientCerts []*x509.Certificate

	store    *sessions.CookieStore
	oauther  oauther
	verifier Verifier
	// groupIDs lists the AAD groups whose members may use the portal.
	groupIDs []string

	sessionTimeout time.Duration
}
|
||||
|
||||
// NewAAD returns an AAD middleware instance.  sessionKey must be exactly 32
// bytes (it keys the securecookie session store).  As a side effect the
// OAuth2 /callback endpoint is registered on unauthenticatedRouter.
func NewAAD(deploymentMode deployment.Mode,
	log *logrus.Entry,
	baseAccessLog *logrus.Entry,
	hostname string,
	sessionKey []byte,
	tenantID string,
	clientID string,
	clientKey *rsa.PrivateKey,
	clientCerts []*x509.Certificate,
	groupIDs []string,
	unauthenticatedRouter *mux.Router,
	verifier Verifier) (AAD, error) {
	if len(sessionKey) != 32 {
		return nil, errors.New("invalid sessionKey")
	}

	a := &aad{
		deploymentMode: deploymentMode,
		log:            log,
		now:            time.Now,
		rt:             http.DefaultTransport,

		tenantID:    tenantID,
		clientID:    clientID,
		clientKey:   clientKey,
		clientCerts: clientCerts,
		store:       sessions.NewCookieStore(sessionKey),
		oauther: &oauth2.Config{
			ClientID:    clientID,
			Endpoint:    microsoft.AzureADEndpoint(tenantID),
			RedirectURL: "https://" + hostname + "/callback",
			Scopes: []string{
				"openid",
				"profile",
			},
		},
		verifier: verifier,
		groupIDs: groupIDs,

		sessionTimeout: time.Hour,
	}

	// Session cookies: non-persistent (MaxAge 0), HTTPS-only, hidden from JS,
	// and SameSite=Lax so the cookie is still sent on the AAD top-level
	// redirect back to /callback.
	a.store.MaxAge(0)
	a.store.Options.Secure = true
	a.store.Options.HttpOnly = true
	a.store.Options.SameSite = http.SameSiteLaxMode

	unauthenticatedRouter.NewRoute().Methods(http.MethodGet).Path("/callback").Handler(Log(baseAccessLog)(http.HandlerFunc(a.callback)))

	return a, nil
}
|
||||
|
||||
// AAD is the early stage handler which adds a username to the context if it
// can. It lets the request through regardless (this is so that failures can be
// logged).
func (a *aad) AAD(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		session, err := a.store.Get(r, SessionName)
		if err != nil {
			a.internalServerError(w, err)
			return
		}

		// Missing or expired expiry means unauthenticated: pass the request
		// through with no user context attached.
		expires, ok := session.Values[SessionKeyExpires].(time.Time)
		if !ok || expires.Before(a.now()) {
			h.ServeHTTP(w, r)
			return
		}

		// Valid session: expose the username and groups to downstream
		// handlers via the request context.
		ctx := r.Context()
		ctx = context.WithValue(ctx, ContextKeyUsername, session.Values[SessionKeyUsername])
		ctx = context.WithValue(ctx, ContextKeyGroups, session.Values[SessionKeyGroups])
		r = r.WithContext(ctx)

		h.ServeHTTP(w, r)
	})
}
|
||||
|
||||
// Redirect is the late stage (post logging) handler which redirects to AAD if
|
||||
// there is no valid user.
|
||||
func (a *aad) Redirect(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
if ctx.Value(ContextKeyUsername) != nil {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
a.redirect(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *aad) Logout(url string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := a.store.Get(r, SessionName)
|
||||
if err != nil {
|
||||
a.internalServerError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
session.Values = nil
|
||||
|
||||
err = session.Save(r, w)
|
||||
if err != nil {
|
||||
a.internalServerError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, url, http.StatusSeeOther)
|
||||
})
|
||||
}
|
||||
|
||||
// redirect starts the OAuth2 authorization code flow: it records the path to
// return to plus a fresh state nonce in the session, then sends the browser
// to the AAD authorize endpoint.
func (a *aad) redirect(w http.ResponseWriter, r *http.Request) {
	session, err := a.store.Get(r, SessionName)
	if err != nil {
		a.internalServerError(w, err)
		return
	}

	// Never bounce back to /callback itself; land on the root instead.
	path := r.URL.Path
	if path == "/callback" {
		path = "/"
	}

	// Random state protects the callback against CSRF.
	state := uuid.NewV4().String()

	// Replace any existing session contents wholesale: the user is not
	// authenticated at this point.
	session.Values = map[interface{}]interface{}{
		sessionKeyRedirectPath: path,
		sessionKeyState:        state,
	}

	err = session.Save(r, w)
	if err != nil {
		a.internalServerError(w, err)
		return
	}

	http.Redirect(w, r, a.oauther.AuthCodeURL(state), http.StatusTemporaryRedirect)
}
|
||||
|
||||
func (a *aad) callback(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
session, err := a.store.Get(r, SessionName)
|
||||
if err != nil {
|
||||
a.internalServerError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
state, ok := session.Values[sessionKeyState].(string)
|
||||
if !ok {
|
||||
a.redirect(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
delete(session.Values, sessionKeyState)
|
||||
|
||||
err = session.Save(r, w)
|
||||
if err != nil {
|
||||
a.internalServerError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if r.FormValue("state") != state {
|
||||
a.internalServerError(w, errors.New("state mismatch"))
|
||||
return
|
||||
}
|
||||
|
||||
if r.FormValue("error") != "" {
|
||||
err := r.FormValue("error")
|
||||
if r.FormValue("error_description") != "" {
|
||||
err += ": " + r.FormValue("error_description")
|
||||
}
|
||||
|
||||
a.internalServerError(w, errors.New(err))
|
||||
return
|
||||
}
|
||||
|
||||
cliCtx := context.WithValue(ctx, oauth2.HTTPClient, &http.Client{
|
||||
Transport: roundtripper.RoundTripperFunc(a.clientAssertion),
|
||||
})
|
||||
|
||||
token, err := a.oauther.Exchange(cliCtx, r.FormValue("code"))
|
||||
if err != nil {
|
||||
a.internalServerError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
a.internalServerError(w, errors.New("id_token not found"))
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := a.verifier.Verify(r.Context(), rawIDToken)
|
||||
if err != nil {
|
||||
a.internalServerError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
var claims claims
|
||||
err = idToken.Claims(&claims)
|
||||
if err != nil {
|
||||
a.internalServerError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !GroupsIntersect(a.groupIDs, claims.Groups) {
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
}
|
||||
|
||||
redirectPath, ok := session.Values[sessionKeyRedirectPath].(string)
|
||||
if !ok {
|
||||
a.internalServerError(w, errors.New("redirect_path not found"))
|
||||
return
|
||||
}
|
||||
|
||||
delete(session.Values, sessionKeyRedirectPath)
|
||||
session.Values[SessionKeyUsername] = claims.PreferredUsername
|
||||
session.Values[SessionKeyGroups] = claims.Groups
|
||||
session.Values[SessionKeyExpires] = a.now().Add(a.sessionTimeout)
|
||||
|
||||
err = session.Save(r, w)
|
||||
if err != nil {
|
||||
a.internalServerError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, redirectPath, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
// clientAssertion adds a JWT client assertion according to
// https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-certificate-credentials
// Treating this as a RoundTripper is more hackery -- this is because the
// underlying oauth2 library is a little unextensible.
func (a *aad) clientAssertion(req *http.Request) (*http.Response, error) {
	oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, a.tenantID)
	if err != nil {
		return nil, err
	}

	// sp only serves as input to SetAuthenticationValues below; the resource
	// argument is a placeholder and never used for a token request here.
	sp, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, a.clientID, a.clientCerts[0], a.clientKey, "unused")
	if err != nil {
		return nil, err
	}

	s := &adal.ServicePrincipalCertificateSecret{
		Certificate: a.clientCerts[0],
		PrivateKey:  a.clientKey,
	}

	err = req.ParseForm()
	if err != nil {
		return nil, err
	}

	// Inject the client assertion fields into the token request's form.
	err = s.SetAuthenticationValues(sp, &req.Form)
	if err != nil {
		return nil, err
	}

	// Re-serialize the mutated form as the request body and fix up
	// Content-Length before handing off to the real transport.
	form := req.Form.Encode()

	req.Body = ioutil.NopCloser(strings.NewReader(form))
	req.ContentLength = int64(len(form))

	return a.rt.RoundTrip(req)
}
|
||||
|
||||
// internalServerError logs err and replies with a generic 500; the error
// detail is deliberately kept out of the response body.
func (a *aad) internalServerError(w http.ResponseWriter, err error) {
	a.log.Warn(err)
	http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
|
||||
|
||||
// GroupsIntersect reports whether the two group lists share at least one
// common entry.
func GroupsIntersect(as, bs []string) bool {
	seen := make(map[string]struct{}, len(as))
	for _, group := range as {
		seen[group] = struct{}{}
	}

	for _, group := range bs {
		if _, ok := seen[group]; ok {
			return true
		}
	}

	return false
}
|
|
@ -0,0 +1,851 @@
|
|||
package middleware
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/form3tech-oss/jwt-go"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/securecookie"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/util/deployment"
|
||||
"github.com/Azure/ARO-RP/pkg/util/roundtripper"
|
||||
utiltls "github.com/Azure/ARO-RP/pkg/util/tls"
|
||||
)
|
||||
|
||||
// Shared client-certificate fixtures for the tests in this file, generated
// once at package init.
var (
	clientkey   *rsa.PrivateKey
	clientcerts []*x509.Certificate
)

func init() {
	var err error

	clientkey, clientcerts, err = utiltls.GenerateKeyAndCertificate("client", nil, nil, false, true)
	if err != nil {
		panic(err)
	}
}
|
||||
|
||||
// noopOauther is a test double for the oauther interface: AuthCodeURL returns
// a fixed marker string and Exchange returns either the configured error or a
// token carrying tokenMap as its extra fields.
type noopOauther struct {
	tokenMap map[string]interface{}
	err      error
}

func (noopOauther) AuthCodeURL(string, ...oauth2.AuthCodeOption) string {
	return "authcodeurl"
}

func (o *noopOauther) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
	if o.err != nil {
		return nil, o.err
	}

	t := oauth2.Token{}
	return t.WithExtra(o.tokenMap), nil
}
|
||||
|
||||
// noopVerifier is a test double for Verifier: it fails with the configured
// error, or otherwise treats the raw token string as JSON claims.
type noopVerifier struct {
	err error
}

func (v *noopVerifier) Verify(ctx context.Context, rawtoken string) (oidctoken, error) {
	if v.err != nil {
		return nil, v.err
	}
	return noopClaims(rawtoken), nil
}

// noopClaims implements oidctoken by unmarshalling itself (raw JSON) into v.
type noopClaims []byte

func (c noopClaims) Claims(v interface{}) error {
	return json.Unmarshal(c, v)
}
|
||||
|
||||
// TestNewAAD checks that constructing the middleware without a 32-byte
// session key is rejected.
func TestNewAAD(t *testing.T) {
	_, err := NewAAD(deployment.Production, nil, nil, "", nil, "", "", nil, nil, nil, nil, nil)
	if err.Error() != "invalid sessionKey" {
		t.Error(err)
	}
}
|
||||
|
||||
func TestAAD(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
request func(*aad) (*http.Request, error)
|
||||
wantStatusCode int
|
||||
wantAuthenticated bool
|
||||
wantUsername string
|
||||
wantGroups []string
|
||||
}{
|
||||
{
|
||||
name: "authenticated",
|
||||
request: func(a *aad) (*http.Request, error) {
|
||||
cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
|
||||
SessionKeyUsername: "username",
|
||||
SessionKeyGroups: []string{"group1", "group2"},
|
||||
SessionKeyExpires: time.Unix(1, 0),
|
||||
}, a.store.Codecs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &http.Request{
|
||||
Header: http.Header{
|
||||
"Cookie": []string{SessionName + "=" + cookie},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
wantAuthenticated: true,
|
||||
wantUsername: "username",
|
||||
wantGroups: []string{"group1", "group2"},
|
||||
},
|
||||
{
|
||||
name: "expired - not authenticated",
|
||||
request: func(a *aad) (*http.Request, error) {
|
||||
cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
|
||||
SessionKeyUsername: "username",
|
||||
SessionKeyGroups: []string{"group1", "group2"},
|
||||
SessionKeyExpires: time.Unix(-1, 0),
|
||||
}, a.store.Codecs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &http.Request{
|
||||
Header: http.Header{
|
||||
"Cookie": []string{SessionName + "=" + cookie},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no cookie - not authenticated",
|
||||
request: func(a *aad) (*http.Request, error) {
|
||||
return &http.Request{}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid cookie",
|
||||
request: func(a *aad) (*http.Request, error) {
|
||||
return &http.Request{
|
||||
Header: http.Header{
|
||||
"Cookie": []string{"session=xxx"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a, err := NewAAD(deployment.Production, logrus.NewEntry(logrus.StandardLogger()), logrus.NewEntry(logrus.StandardLogger()), "", make([]byte, 32), "", "", nil, nil, nil, mux.NewRouter(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a.(*aad).now = func() time.Time { return time.Unix(0, 0) }
|
||||
|
||||
var username string
|
||||
var usernameok bool
|
||||
var groups []string
|
||||
var groupsok bool
|
||||
h := a.AAD(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
username, usernameok = r.Context().Value(ContextKeyUsername).(string)
|
||||
groups, groupsok = r.Context().Value(ContextKeyGroups).([]string)
|
||||
}))
|
||||
|
||||
r, err := tt.request(a.(*aad))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if tt.wantStatusCode != 0 && w.Code != tt.wantStatusCode {
|
||||
t.Error(w.Code)
|
||||
}
|
||||
|
||||
if username != tt.wantUsername {
|
||||
t.Error(username)
|
||||
}
|
||||
if usernameok != tt.wantAuthenticated {
|
||||
t.Error(usernameok)
|
||||
}
|
||||
if !reflect.DeepEqual(groups, tt.wantGroups) {
|
||||
t.Error(groups)
|
||||
}
|
||||
if groupsok != tt.wantAuthenticated {
|
||||
t.Error(groupsok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirect(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
request func(*aad) (*http.Request, error)
|
||||
wantStatusCode int
|
||||
wantAuthenticated bool
|
||||
}{
|
||||
{
|
||||
name: "authenticated",
|
||||
request: func(a *aad) (*http.Request, error) {
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, ContextKeyUsername, "user")
|
||||
return http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantAuthenticated: true,
|
||||
},
|
||||
{
|
||||
name: "not authenticated",
|
||||
request: func(a *aad) (*http.Request, error) {
|
||||
ctx := context.Background()
|
||||
return http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
|
||||
},
|
||||
wantStatusCode: http.StatusTemporaryRedirect,
|
||||
},
|
||||
{
|
||||
name: "not authenticated",
|
||||
request: func(a *aad) (*http.Request, error) {
|
||||
ctx := context.Background()
|
||||
return http.NewRequestWithContext(ctx, http.MethodGet, "/callback", nil)
|
||||
},
|
||||
wantStatusCode: http.StatusTemporaryRedirect,
|
||||
},
|
||||
{
|
||||
name: "invalid cookie",
|
||||
request: func(a *aad) (*http.Request, error) {
|
||||
return &http.Request{
|
||||
Header: http.Header{
|
||||
"Cookie": []string{"session=xxx"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a, err := NewAAD(deployment.Production, logrus.NewEntry(logrus.StandardLogger()), logrus.NewEntry(logrus.StandardLogger()), "", make([]byte, 32), "", "", nil, nil, nil, mux.NewRouter(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var authenticated bool
|
||||
h := a.Redirect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, authenticated = r.Context().Value(ContextKeyUsername).(string)
|
||||
}))
|
||||
|
||||
r, err := tt.request(a.(*aad))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != tt.wantStatusCode {
|
||||
t.Error(w.Code)
|
||||
}
|
||||
|
||||
if authenticated != tt.wantAuthenticated {
|
||||
t.Fatal(authenticated)
|
||||
}
|
||||
|
||||
if tt.wantStatusCode == http.StatusInternalServerError {
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantAuthenticated {
|
||||
if !strings.HasPrefix(w.Header().Get("Location"), "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?client_id=&redirect_uri=https%3A%2F%2F%2Fcallback&response_type=code&scope=openid+profile&state=") {
|
||||
t.Error(w.Header().Get("Location"))
|
||||
}
|
||||
|
||||
var m map[interface{}]interface{}
|
||||
cookies := w.Result().Cookies()
|
||||
err = securecookie.DecodeMulti(SessionName, cookies[len(cookies)-1].Value, &m, a.(*aad).store.Codecs...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(m) != 2 {
|
||||
t.Error(len(m))
|
||||
}
|
||||
|
||||
if redirectPath, ok := m[sessionKeyRedirectPath].(string); !ok ||
|
||||
redirectPath != "/" {
|
||||
t.Error(m[sessionKeyRedirectPath])
|
||||
}
|
||||
|
||||
if state, ok := m[sessionKeyState].(string); !ok ||
|
||||
uuid.FromStringOrNil(state) == uuid.Nil {
|
||||
t.Error(m[sessionKeyState])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLogout exercises the AAD Logout handler: with or without a valid
// session cookie the caller is redirected (303) to the supplied path and
// the session is emptied; an undecodable cookie yields a 500.
func TestLogout(t *testing.T) {
	for _, tt := range []struct {
		name           string
		request        func(*aad) (*http.Request, error)
		wantStatusCode int
	}{
		{
			name: "authenticated",
			request: func(a *aad) (*http.Request, error) {
				// build a valid, unexpired session cookie
				cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
					SessionKeyUsername: "username",
					SessionKeyGroups:   []string{"group1", "group2"},
					SessionKeyExpires:  time.Unix(1, 0),
				}, a.store.Codecs...)
				if err != nil {
					return nil, err
				}

				return &http.Request{
					URL: &url.URL{},
					Header: http.Header{
						"Cookie": []string{SessionName + "=" + cookie},
					},
				}, nil
			},
			wantStatusCode: http.StatusSeeOther,
		},
		{
			name: "no cookie - not authenticated",
			request: func(a *aad) (*http.Request, error) {
				return &http.Request{
					URL: &url.URL{},
				}, nil
			},
			wantStatusCode: http.StatusSeeOther,
		},
		{
			name: "invalid cookie",
			request: func(a *aad) (*http.Request, error) {
				return &http.Request{
					Header: http.Header{
						"Cookie": []string{"session=xxx"},
					},
				}, nil
			},
			wantStatusCode: http.StatusInternalServerError,
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			a, err := NewAAD(deployment.Production, logrus.NewEntry(logrus.StandardLogger()), logrus.NewEntry(logrus.StandardLogger()), "", make([]byte, 32), "", "", nil, nil, nil, mux.NewRouter(), nil)
			if err != nil {
				t.Fatal(err)
			}

			h := a.Logout("/bye")

			r, err := tt.request(a.(*aad))
			if err != nil {
				t.Fatal(err)
			}

			w := httptest.NewRecorder()

			h.ServeHTTP(w, r)

			if w.Code != tt.wantStatusCode {
				t.Error(w.Code)
			}

			// a 500 writes no redirect or cookie; nothing more to check
			if tt.wantStatusCode == http.StatusInternalServerError {
				return
			}

			if w.Header().Get("Location") != "/bye" {
				t.Error(w.Header().Get("Location"))
			}

			// the replacement session cookie must decode to an empty map,
			// i.e. the session was cleared
			var m map[interface{}]interface{}
			cookies := w.Result().Cookies()
			err = securecookie.DecodeMulti(SessionName, cookies[len(cookies)-1].Value, &m, a.(*aad).store.Codecs...)
			if err != nil {
				t.Fatal(err)
			}

			if len(m) != 0 {
				t.Error(len(m))
			}
		})
	}
}
|
||||
|
||||
// TestCallback exercises the AAD OAuth2 callback handler across the full
// matrix of outcomes: a successful code exchange (session populated and
// redirect to the saved path), the various internal errors (bad cookie,
// state mismatch, provider error, token/claim failures), a group
// mismatch (403), and a missing saved redirect path.
func TestCallback(t *testing.T) {
	clientID := "00000000-0000-0000-0000-000000000000"
	groups := []string{
		"00000000-0000-0000-0000-000000000001",
	}
	username := "user"

	// a minimal id_token body carrying the expected group and username;
	// the noop verifier passes it through without signature checks
	idToken, err := json.Marshal(claims{
		Groups:            groups,
		PreferredUsername: username,
	})
	if err != nil {
		t.Fatal(err)
	}

	for _, tt := range []struct {
		name              string
		request           func(*aad) (*http.Request, error)
		oauther           oauther
		verifier          Verifier
		wantAuthenticated bool
		wantError         string
		wantForbidden     bool
	}{
		{
			name: "success",
			request: func(a *aad) (*http.Request, error) {
				uuid := uuid.NewV4().String()

				// session cookie holding the CSRF state and redirect path
				cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
					sessionKeyState:        uuid,
					sessionKeyRedirectPath: "/",
				}, a.store.Codecs...)
				if err != nil {
					return nil, err
				}

				return &http.Request{
					URL: &url.URL{},
					Header: http.Header{
						"Cookie": []string{SessionName + "=" + cookie},
					},
					Form: url.Values{
						"state": []string{uuid},
					},
				}, nil
			},
			oauther: &noopOauther{
				tokenMap: map[string]interface{}{
					"id_token": string(idToken),
				},
			},
			verifier:          &noopVerifier{},
			wantAuthenticated: true,
		},
		{
			name: "fail - invalid cookie",
			request: func(a *aad) (*http.Request, error) {
				return &http.Request{
					Header: http.Header{
						"Cookie": []string{"session=xxx"},
					},
				}, nil
			},
			wantError: "Internal Server Error\n",
		},
		{
			// no cookie at all: the handler cannot find the saved state and
			// falls back to re-issuing the auth code redirect (default branch)
			name: "fail - corrupt sessionKeyState",
			request: func(a *aad) (*http.Request, error) {
				return &http.Request{
					URL: &url.URL{},
				}, nil
			},
			oauther: &noopOauther{},
		},
		{
			name: "fail - state mismatch",
			request: func(a *aad) (*http.Request, error) {
				uuid := uuid.NewV4().String()

				cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
					sessionKeyState:        uuid,
					sessionKeyRedirectPath: "/",
				}, a.store.Codecs...)
				if err != nil {
					return nil, err
				}

				return &http.Request{
					URL: &url.URL{},
					Header: http.Header{
						"Cookie": []string{SessionName + "=" + cookie},
					},
					Form: url.Values{
						"state": []string{"bad"},
					},
				}, nil
			},
			wantError: "Internal Server Error\n",
		},
		{
			name: "fail - error returned",
			request: func(a *aad) (*http.Request, error) {
				uuid := uuid.NewV4().String()

				cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
					sessionKeyState:        uuid,
					sessionKeyRedirectPath: "/",
				}, a.store.Codecs...)
				if err != nil {
					return nil, err
				}

				return &http.Request{
					URL: &url.URL{},
					Header: http.Header{
						"Cookie": []string{SessionName + "=" + cookie},
					},
					Form: url.Values{
						"state":             []string{uuid},
						"error":             []string{"bad things happened."},
						"error_description": []string{"really bad things."},
					},
				}, nil
			},
			wantError: "Internal Server Error\n",
		},
		{
			name: "fail - oauther failed",
			request: func(a *aad) (*http.Request, error) {
				uuid := uuid.NewV4().String()

				cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
					sessionKeyState:        uuid,
					sessionKeyRedirectPath: "/",
				}, a.store.Codecs...)
				if err != nil {
					return nil, err
				}

				return &http.Request{
					URL: &url.URL{},
					Header: http.Header{
						"Cookie": []string{SessionName + "=" + cookie},
					},
					Form: url.Values{
						"state": []string{uuid},
					},
				}, nil
			},
			oauther: &noopOauther{
				err: fmt.Errorf("failed"),
			},
			wantError: "Internal Server Error\n",
		},
		{
			name: "fail - no idtoken",
			request: func(a *aad) (*http.Request, error) {
				uuid := uuid.NewV4().String()

				cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
					sessionKeyState:        uuid,
					sessionKeyRedirectPath: "/",
				}, a.store.Codecs...)
				if err != nil {
					return nil, err
				}

				return &http.Request{
					URL: &url.URL{},
					Header: http.Header{
						"Cookie": []string{SessionName + "=" + cookie},
					},
					Form: url.Values{
						"state": []string{uuid},
					},
				}, nil
			},
			oauther:   &noopOauther{},
			wantError: "Internal Server Error\n",
		},
		{
			name: "fail - verifier error",
			request: func(a *aad) (*http.Request, error) {
				uuid := uuid.NewV4().String()

				cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
					sessionKeyState:        uuid,
					sessionKeyRedirectPath: "/",
				}, a.store.Codecs...)
				if err != nil {
					return nil, err
				}

				return &http.Request{
					URL: &url.URL{},
					Header: http.Header{
						"Cookie": []string{SessionName + "=" + cookie},
					},
					Form: url.Values{
						"state": []string{uuid},
					},
				}, nil
			},
			oauther: &noopOauther{
				tokenMap: map[string]interface{}{"id_token": ""},
			},
			verifier: &noopVerifier{
				err: fmt.Errorf("failed"),
			},
			wantError: "Internal Server Error\n",
		},
		{
			name: "fail - invalid claims",
			request: func(a *aad) (*http.Request, error) {
				uuid := uuid.NewV4().String()

				cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
					sessionKeyState:        uuid,
					sessionKeyRedirectPath: "/",
				}, a.store.Codecs...)
				if err != nil {
					return nil, err
				}

				return &http.Request{
					URL: &url.URL{},
					Header: http.Header{
						"Cookie": []string{SessionName + "=" + cookie},
					},
					Form: url.Values{
						"state": []string{uuid},
					},
				}, nil
			},
			oauther: &noopOauther{
				tokenMap: map[string]interface{}{
					"id_token": "",
				},
			},
			verifier:  &noopVerifier{},
			wantError: "Internal Server Error\n",
		},
		{
			// "null" claims decode to no groups, so membership check fails
			name: "fail - group mismatch",
			request: func(a *aad) (*http.Request, error) {
				uuid := uuid.NewV4().String()

				cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
					sessionKeyState:        uuid,
					sessionKeyRedirectPath: "/",
				}, a.store.Codecs...)
				if err != nil {
					return nil, err
				}

				return &http.Request{
					URL: &url.URL{},
					Header: http.Header{
						"Cookie": []string{SessionName + "=" + cookie},
					},
					Form: url.Values{
						"state": []string{uuid},
					},
				}, nil
			},
			oauther: &noopOauther{
				tokenMap: map[string]interface{}{
					"id_token": "null",
				},
			},
			verifier:      &noopVerifier{},
			wantForbidden: true,
		},
		{
			name: "fail - missing redirect_path",
			request: func(a *aad) (*http.Request, error) {
				uuid := uuid.NewV4().String()

				// state only, no sessionKeyRedirectPath
				cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
					sessionKeyState: uuid,
				}, a.store.Codecs...)
				if err != nil {
					return nil, err
				}

				return &http.Request{
					URL: &url.URL{},
					Header: http.Header{
						"Cookie": []string{SessionName + "=" + cookie},
					},
					Form: url.Values{
						"state": []string{uuid},
					},
				}, nil
			},
			oauther: &noopOauther{
				tokenMap: map[string]interface{}{
					"id_token": string(idToken),
				},
			},
			verifier:  &noopVerifier{},
			wantError: "Internal Server Error\n",
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			a, err := NewAAD(deployment.Production, logrus.NewEntry(logrus.StandardLogger()), logrus.NewEntry(logrus.StandardLogger()), "", make([]byte, 32), "", clientID, clientkey, clientcerts, groups, mux.NewRouter(), tt.verifier)
			if err != nil {
				t.Fatal(err)
			}
			// pin the clock and substitute the fake token exchanger
			a.(*aad).now = func() time.Time { return time.Unix(0, 0) }
			a.(*aad).oauther = tt.oauther

			r, err := tt.request(a.(*aad))
			if err != nil {
				t.Fatal(err)
			}

			w := httptest.NewRecorder()

			a.(*aad).callback(w, r)

			if tt.wantError != "" {
				if w.Code != http.StatusInternalServerError {
					t.Error(w.Code)
				}

				if w.Body.String() != tt.wantError {
					t.Error(w.Body.String())
				}

				return
			}

			// decode the final session cookie written by the handler
			var m map[interface{}]interface{}
			cookies := w.Result().Cookies()
			err = securecookie.DecodeMulti(SessionName, cookies[len(cookies)-1].Value, &m, a.(*aad).store.Codecs...)
			if err != nil {
				t.Fatal(err)
			}

			switch {
			case tt.wantAuthenticated:
				// successful login: redirect to the saved path with a
				// populated session (username, groups, expiry)
				if w.Code != http.StatusTemporaryRedirect {
					t.Error(w.Code)
				}

				if w.Header().Get("Location") != "/" {
					t.Error(w.Header().Get("Location"))
				}

				if len(m) != 3 {
					t.Error(len(m))
				}

				if expires, ok := m[SessionKeyExpires].(time.Time); !ok ||
					expires != time.Unix(3600, 0) {
					t.Error(m[SessionKeyExpires])
				}

				if sessionGroups, ok := m[SessionKeyGroups].([]string); !ok ||
					!reflect.DeepEqual(sessionGroups, groups) {
					t.Error(m[SessionKeyGroups])
				}

				if sessionUsername, ok := m[SessionKeyUsername].(string); !ok ||
					sessionUsername != username {
					t.Error(m[SessionKeyUsername])
				}

			case tt.wantForbidden:
				// group mismatch: 403, session keeps only the redirect path
				if w.Code != http.StatusForbidden {
					t.Error(w.Code)
				}

				if w.Header().Get("Location") != "/" {
					t.Error(w.Header().Get("Location"))
				}

				if len(m) != 1 {
					t.Error(len(m))
				}

				if redirectPath, ok := m[sessionKeyRedirectPath].(string); !ok ||
					redirectPath != "/" {
					t.Error(m[sessionKeyRedirectPath])
				}

			default:
				// no usable state: handler restarts the auth code flow
				if w.Code != http.StatusTemporaryRedirect {
					t.Error(w.Code)
				}

				if w.Header().Get("Location") != "/authcodeurl" {
					t.Error(w.Header().Get("Location"))
				}

				if len(m) != 2 {
					t.Error(len(m))
				}

				if redirectPath, ok := m[sessionKeyRedirectPath].(string); !ok ||
					redirectPath != "" {
					t.Error(m[sessionKeyRedirectPath])
				}

				if state, ok := m[sessionKeyState].(string); !ok ||
					uuid.FromStringOrNil(state) == uuid.Nil {
					t.Error(m[sessionKeyState])
				}
			}
		})
	}
}
|
||||
|
||||
// TestClientAssertion checks that clientAssertion leaves existing form
// values intact, sets the JWT bearer assertion type, and attaches a
// client_assertion JWT that validates against the test client key.
func TestClientAssertion(t *testing.T) {
	clientID := "00000000-0000-0000-0000-000000000000"

	a, err := NewAAD(deployment.Production, logrus.NewEntry(logrus.StandardLogger()), logrus.NewEntry(logrus.StandardLogger()), "", make([]byte, 32), "", clientID, clientkey, clientcerts, nil, mux.NewRouter(), nil)
	if err != nil {
		t.Fatal(err)
	}
	// stub the transport so no network traffic occurs
	a.(*aad).rt = roundtripper.RoundTripperFunc(func(*http.Request) (*http.Response, error) {
		return nil, nil
	})

	req, err := http.NewRequest(http.MethodGet, "https://localhost/", nil)
	if err != nil {
		t.Fatal(err)
	}
	req.Form = url.Values{"test": []string{"value"}}

	_, err = a.(*aad).clientAssertion(req)
	if err != nil {
		t.Fatal(err)
	}

	// pre-existing form values must be preserved
	if req.Form.Get("test") != "value" {
		t.Error(req.Form.Get("test"))
	}

	if req.Form.Get("client_assertion_type") != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
		t.Error(req.Form.Get("client_assertion_type"))
	}

	// the assertion must be a JWT signed by the test client key
	p := &jwt.Parser{}
	_, err = p.Parse(req.Form.Get("client_assertion"), func(*jwt.Token) (interface{}, error) {
		return &clientkey.PublicKey, nil
	})
	if err != nil {
		t.Fatal(err)
	}
}
|
|
@ -0,0 +1,50 @@
|
|||
package middleware
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
uuid "github.com/satori/go.uuid"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/database"
|
||||
)
|
||||
|
||||
// Bearer validates a Bearer token and adds the corresponding username to the
|
||||
// context if it checks out. It lets the request through regardless (this is so
|
||||
// that failures can be logged).
|
||||
func Bearer(dbPortal database.Portal) func(http.Handler) http.Handler {
|
||||
return func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
authorization := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(authorization, "Bearer ") {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := uuid.FromString(strings.TrimPrefix(authorization, "Bearer "))
|
||||
if err != nil {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
portalDoc, err := dbPortal.Get(ctx, token.String())
|
||||
if err != nil {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, ContextKeyUsername, portalDoc.Portal.Username)
|
||||
ctx = context.WithValue(ctx, ContextKeyPortalDoc, portalDoc)
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
package middleware
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
testdatabase "github.com/Azure/ARO-RP/test/database"
|
||||
)
|
||||
|
||||
func TestBearer(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
fixture func(*testdatabase.Fixture)
|
||||
request func() (*http.Request, error)
|
||||
wantAuthenticated bool
|
||||
wantUsername string
|
||||
}{
|
||||
{
|
||||
name: "authenticated",
|
||||
fixture: func(fixture *testdatabase.Fixture) {
|
||||
fixture.AddPortalDocuments(&api.PortalDocument{
|
||||
ID: "00000000-0000-0000-0000-000000000000",
|
||||
Portal: &api.Portal{
|
||||
Username: "username",
|
||||
},
|
||||
})
|
||||
},
|
||||
request: func() (*http.Request, error) {
|
||||
return &http.Request{
|
||||
Header: http.Header{
|
||||
"Authorization": []string{"Bearer 00000000-0000-0000-0000-000000000000"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
wantAuthenticated: true,
|
||||
wantUsername: "username",
|
||||
},
|
||||
{
|
||||
name: "not authenticated - no header",
|
||||
request: func() (*http.Request, error) {
|
||||
return &http.Request{}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not authenticated - bad header",
|
||||
request: func() (*http.Request, error) {
|
||||
return &http.Request{
|
||||
Header: http.Header{
|
||||
"Authorization": []string{"Bearer bad"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not authenticated - berarer not found",
|
||||
fixture: func(fixture *testdatabase.Fixture) {
|
||||
fixture.AddPortalDocuments(&api.PortalDocument{
|
||||
ID: "00000000-0000-0000-0000-000000000000",
|
||||
Portal: &api.Portal{
|
||||
Username: "username",
|
||||
},
|
||||
})
|
||||
},
|
||||
request: func() (*http.Request, error) {
|
||||
return &http.Request{
|
||||
Header: http.Header{
|
||||
"Authorization": []string{"Bearer 10000000-0000-0000-0000-000000000000"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
} {
|
||||
dbPortal, _ := testdatabase.NewFakePortal()
|
||||
|
||||
fixture := testdatabase.NewFixture().
|
||||
WithPortal(dbPortal)
|
||||
|
||||
if tt.fixture != nil {
|
||||
tt.fixture(fixture)
|
||||
}
|
||||
|
||||
err := fixture.Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var username string
|
||||
var usernameok bool
|
||||
var portaldoc *api.PortalDocument
|
||||
var portaldocok bool
|
||||
h := Bearer(dbPortal)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
username, usernameok = r.Context().Value(ContextKeyUsername).(string)
|
||||
portaldoc, portaldocok = r.Context().Value(ContextKeyPortalDoc).(*api.PortalDocument)
|
||||
}))
|
||||
|
||||
r, err := tt.request()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if username != tt.wantUsername {
|
||||
t.Error(username)
|
||||
}
|
||||
if usernameok != tt.wantAuthenticated {
|
||||
t.Error(usernameok)
|
||||
}
|
||||
if portaldocok != tt.wantAuthenticated {
|
||||
t.Error(portaldocok)
|
||||
}
|
||||
if tt.wantAuthenticated && portaldoc.Portal.Username != username {
|
||||
t.Error(portaldoc.Portal.Username)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
package middleware
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
utillog "github.com/Azure/ARO-RP/pkg/util/log"
|
||||
)
|
||||
|
||||
// logResponseWriter wraps an http.ResponseWriter, recording the status
// code and the number of response body bytes written so they can be
// reported by the Log middleware after the handler returns.
type logResponseWriter struct {
	http.ResponseWriter

	// statusCode is the last value passed to WriteHeader
	statusCode int
	// bytes is the cumulative count of bytes successfully written
	bytes int
}
|
||||
|
||||
func (w *logResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker := w.ResponseWriter.(http.Hijacker)
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
|
||||
func (w *logResponseWriter) Write(b []byte) (int, error) {
|
||||
n, err := w.ResponseWriter.Write(b)
|
||||
w.bytes += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *logResponseWriter) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
w.statusCode = statusCode
|
||||
}
|
||||
|
||||
// logReadCloser wraps an io.ReadCloser (the request body), counting the
// bytes read so the Log middleware can report them.
type logReadCloser struct {
	io.ReadCloser

	// bytes is the cumulative count of bytes read from the body
	bytes int
}
|
||||
|
||||
func (rc *logReadCloser) Read(b []byte) (int, error) {
|
||||
n, err := rc.ReadCloser.Read(b)
|
||||
rc.bytes += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Log returns access-logging middleware: it logs a "read request" line
// before invoking the handler and a "sent response" line (with byte
// counts, duration and status code) after the handler returns, via the
// deferred function.
func Log(baseLog *logrus.Entry) func(http.Handler) http.Handler {
	return func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			t := time.Now()

			// wrap body and writer so byte counts and status are captured
			r.Body = &logReadCloser{ReadCloser: r.Body}
			w = &logResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}

			log := baseLog
			log = utillog.EnrichWithPath(log, r.URL.Path)

			// username was placed in the context by an auth middleware
			// upstream (if any); empty string otherwise
			username, _ := r.Context().Value(ContextKeyUsername).(string)

			// request fields are captured here, before the handler can
			// mutate the request
			log = log.WithFields(logrus.Fields{
				"request_method":      r.Method,
				"request_path":        r.URL.Path,
				"request_proto":       r.Proto,
				"request_remote_addr": r.RemoteAddr,
				"request_user_agent":  r.UserAgent(),
				"username":            username,
			})
			log.Print("read request")

			// deferred so the response line is emitted even if the handler
			// panics further up the chain
			defer func() {
				log.WithFields(logrus.Fields{
					"body_read_bytes":      r.Body.(*logReadCloser).bytes,
					"body_written_bytes":   w.(*logResponseWriter).bytes,
					"duration":             time.Since(t).Seconds(),
					"response_status_code": w.(*logResponseWriter).statusCode,
				}).Print("sent response")
			}()

			h.ServeHTTP(w, r)
		})
	}
}
|
|
@ -0,0 +1,82 @@
|
|||
package middleware
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/onsi/gomega"
|
||||
"github.com/onsi/gomega/types"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
testlog "github.com/Azure/ARO-RP/test/util/log"
|
||||
)
|
||||
|
||||
// TestLog drives the Log middleware with a POST carrying a 4-byte body
// and asserts both emitted log entries ("read request" and
// "sent response") carry the expected fields, even though the handler
// mutates the request.
func TestLog(t *testing.T) {
	h, log := testlog.New()

	// simulate an upstream auth middleware having set the username
	ctx := context.WithValue(context.Background(), ContextKeyUsername, "username")
	r, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://localhost/", strings.NewReader("body"))
	if err != nil {
		t.Fatal(err)
	}
	r.RemoteAddr = "127.0.0.1:1234"
	r.Header.Set("User-Agent", "user-agent")

	w := httptest.NewRecorder()

	Log(log)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		r.URL = nil // mutate the request

		_ = w.(http.Hijacker) // must implement http.Hijacker

		w.WriteHeader(http.StatusOK)
		_, _ = io.Copy(w, r.Body)
	})).ServeHTTP(w, r)

	if w.Code != http.StatusOK {
		t.Error(w.Code)
	}

	expected := []map[string]types.GomegaMatcher{
		{
			"msg":                 gomega.Equal("read request"),
			"level":               gomega.Equal(logrus.InfoLevel),
			"request_method":      gomega.Equal("POST"),
			"request_path":        gomega.Equal("/"),
			"request_proto":       gomega.Equal("HTTP/1.1"),
			"request_remote_addr": gomega.Equal("127.0.0.1:1234"),
			"request_user_agent":  gomega.Equal("user-agent"),
			"username":            gomega.Equal("username"),
		},
		{
			"msg":                  gomega.Equal("sent response"),
			"level":                gomega.Equal(logrus.InfoLevel),
			"body_read_bytes":      gomega.Equal(4),
			"body_written_bytes":   gomega.Equal(4),
			"response_status_code": gomega.Equal(http.StatusOK),
			"request_method":       gomega.Equal("POST"),
			"request_path":         gomega.Equal("/"),
			"request_proto":        gomega.Equal("HTTP/1.1"),
			"request_remote_addr":  gomega.Equal("127.0.0.1:1234"),
			"request_user_agent":   gomega.Equal("user-agent"),
			"username":             gomega.Equal("username"),
		},
	}

	err = testlog.AssertLoggingOutput(h, expected)
	if err != nil {
		t.Error(err)
	}

	// NOTE(review): this dump of captured entries to stdout looks like
	// leftover debugging; consider removing it (the "fmt" import would
	// have to go with it).
	for _, e := range h.Entries {
		fmt.Println(e)
	}
}
|
|
@ -0,0 +1,12 @@
|
|||
package middleware
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
// contextKey is an unexported type for request-context keys defined in
// this package, so they cannot collide with keys from other packages.
type contextKey int

const (
	// ContextKeyUsername holds the authenticated user's name (string).
	ContextKeyUsername contextKey = iota
	// ContextKeyGroups holds the user's group IDs ([]string).
	ContextKeyGroups
	// ContextKeyPortalDoc holds the *api.PortalDocument matched by the
	// Bearer middleware.
	ContextKeyPortalDoc
)
|
|
@ -0,0 +1,26 @@
|
|||
package middleware
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func Panic(log *logrus.Entry) func(http.Handler) http.Handler {
|
||||
return func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
log.Errorf("panic: %#v\n%s\n", e, string(debug.Stack()))
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package middleware
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/onsi/gomega"
|
||||
"github.com/onsi/gomega/types"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
testlog "github.com/Azure/ARO-RP/test/util/log"
|
||||
)
|
||||
|
||||
func TestPanic(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
panictext string
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
},
|
||||
{
|
||||
name: "panic",
|
||||
panictext: "random error",
|
||||
},
|
||||
} {
|
||||
h, log := testlog.New()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
Panic(log)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if tt.panictext != "" {
|
||||
panic(tt.panictext)
|
||||
}
|
||||
})).ServeHTTP(w, nil)
|
||||
|
||||
var expected []map[string]types.GomegaMatcher
|
||||
if tt.panictext == "" {
|
||||
if w.Code != http.StatusOK {
|
||||
t.Error(w.Code)
|
||||
}
|
||||
} else {
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Error(w.Code)
|
||||
}
|
||||
|
||||
expected = []map[string]types.GomegaMatcher{
|
||||
{
|
||||
"msg": gomega.MatchRegexp(regexp.QuoteMeta(tt.panictext)),
|
||||
"level": gomega.Equal(logrus.ErrorLevel),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
err := testlog.AssertLoggingOutput(h, expected)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,286 @@
|
|||
package portal
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"html/template"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/csrf"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
"github.com/Azure/ARO-RP/pkg/database"
|
||||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
frontendmiddleware "github.com/Azure/ARO-RP/pkg/frontend/middleware"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/kubeconfig"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/middleware"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/prometheus"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/ssh"
|
||||
"github.com/Azure/ARO-RP/pkg/proxy"
|
||||
)
|
||||
|
||||
// Runnable is implemented by long-running servers: Run blocks serving
// until an error occurs (or the underlying listener is closed).
type Runnable interface {
	Run(context.Context) error
}
|
||||
|
||||
// portal holds the state of the admin portal server: listeners, loggers,
// TLS and AAD credentials, group-based authorization configuration,
// database clients and the parsed index.html template.
type portal struct {
	env           env.Core
	log           *logrus.Entry
	baseAccessLog *logrus.Entry
	l             net.Listener // HTTPS listener for the portal itself
	sshl          net.Listener // listener handed to the SSH component
	verifier      middleware.Verifier

	hostname     string
	servingKey   *rsa.PrivateKey     // TLS serving key
	servingCerts []*x509.Certificate // TLS serving certificate chain
	clientID     string              // AAD client application ID
	clientKey    *rsa.PrivateKey
	clientCerts  []*x509.Certificate
	sessionKey   []byte // key for session cookies and CSRF protection
	sshKey       *rsa.PrivateKey

	groupIDs         []string // AAD groups allowed to use the portal
	elevatedGroupIDs []string // AAD groups additionally allowed SSH/kubeconfig access

	dbPortal            database.Portal
	dbOpenShiftClusters database.OpenShiftClusters

	dialer proxy.Dialer

	t *template.Template // parsed index.html, set in Run

	aad middleware.AAD
}
|
||||
|
||||
// NewPortal constructs the admin portal server.  It only records
// configuration; nothing runs until Run is called on the returned value.
func NewPortal(env env.Core,
	log *logrus.Entry,
	baseAccessLog *logrus.Entry,
	l net.Listener,
	sshl net.Listener,
	verifier middleware.Verifier,
	hostname string,
	servingKey *rsa.PrivateKey,
	servingCerts []*x509.Certificate,
	clientID string,
	clientKey *rsa.PrivateKey,
	clientCerts []*x509.Certificate,
	sessionKey []byte,
	sshKey *rsa.PrivateKey,
	groupIDs []string,
	elevatedGroupIDs []string,
	dbOpenShiftClusters database.OpenShiftClusters,
	dbPortal database.Portal,
	dialer proxy.Dialer) Runnable {
	return &portal{
		env:           env,
		log:           log,
		baseAccessLog: baseAccessLog,
		l:             l,
		sshl:          sshl,
		verifier:      verifier,

		hostname:     hostname,
		servingKey:   servingKey,
		servingCerts: servingCerts,
		clientID:     clientID,
		clientKey:    clientKey,
		clientCerts:  clientCerts,
		sessionKey:   sessionKey,
		sshKey:       sshKey,

		groupIDs:         groupIDs,
		elevatedGroupIDs: elevatedGroupIDs,

		dbOpenShiftClusters: dbOpenShiftClusters,
		dbPortal:            dbPortal,

		dialer: dialer,
	}
}
|
||||
|
||||
func (p *portal) Run(ctx context.Context) error {
|
||||
asset, err := Asset("index.html")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.t, err = template.New("index.html").Parse(string(asset))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := &tls.Config{
|
||||
Certificates: []tls.Certificate{
|
||||
{
|
||||
PrivateKey: p.servingKey,
|
||||
},
|
||||
},
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
},
|
||||
PreferServerCipherSuites: true,
|
||||
SessionTicketsDisabled: true,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CurvePreferences: []tls.CurveID{
|
||||
tls.CurveP256,
|
||||
tls.X25519,
|
||||
},
|
||||
}
|
||||
|
||||
for _, cert := range p.servingCerts {
|
||||
config.Certificates[0].Certificate = append(config.Certificates[0].Certificate, cert.Raw)
|
||||
}
|
||||
|
||||
r := mux.NewRouter()
|
||||
r.Use(middleware.Panic(p.log))
|
||||
|
||||
unauthenticatedRouter := r.NewRoute().Subrouter()
|
||||
p.unauthenticatedRoutes(unauthenticatedRouter)
|
||||
|
||||
p.aad, err = middleware.NewAAD(p.env.DeploymentMode(), p.log, p.baseAccessLog, p.hostname, p.sessionKey, p.env.TenantID(), p.clientID, p.clientKey, p.clientCerts, p.groupIDs, unauthenticatedRouter, p.verifier)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
aadAuthenticatedRouter := r.NewRoute().Subrouter()
|
||||
aadAuthenticatedRouter.Use(p.aad.AAD)
|
||||
aadAuthenticatedRouter.Use(middleware.Log(p.baseAccessLog))
|
||||
aadAuthenticatedRouter.Use(p.aad.Redirect)
|
||||
aadAuthenticatedRouter.Use(csrf.Protect(p.sessionKey, csrf.SameSite(csrf.SameSiteStrictMode), csrf.MaxAge(0)))
|
||||
|
||||
p.aadAuthenticatedRoutes(aadAuthenticatedRouter)
|
||||
|
||||
ssh, err := ssh.New(p.env, p.log, p.baseAccessLog, p.sshl, p.sshKey, p.elevatedGroupIDs, p.dbOpenShiftClusters, p.dbPortal, p.dialer, aadAuthenticatedRouter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = ssh.Run()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
kubeconfig.New(p.log, p.baseAccessLog, p.servingCerts[0], p.elevatedGroupIDs, p.dbOpenShiftClusters, p.dbPortal, p.dialer, aadAuthenticatedRouter, unauthenticatedRouter)
|
||||
|
||||
prometheus.New(p.log, p.dbOpenShiftClusters, p.dialer, aadAuthenticatedRouter)
|
||||
|
||||
s := &http.Server{
|
||||
Handler: frontendmiddleware.Lowercase(r),
|
||||
ReadTimeout: 10 * time.Second,
|
||||
IdleTimeout: 2 * time.Minute,
|
||||
ErrorLog: log.New(p.log.Writer(), "", 0),
|
||||
BaseContext: func(net.Listener) context.Context { return ctx },
|
||||
}
|
||||
|
||||
return s.Serve(tls.NewListener(p.l, config))
|
||||
}
|
||||
|
||||
func (p *portal) unauthenticatedRoutes(r *mux.Router) {
|
||||
logger := middleware.Log(p.baseAccessLog)
|
||||
|
||||
r.NewRoute().Methods(http.MethodGet).Path("/healthz/ready").Handler(logger(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})))
|
||||
}
|
||||
|
||||
func (p *portal) aadAuthenticatedRoutes(r *mux.Router) {
|
||||
for _, name := range AssetNames() {
|
||||
if name == "index.html" {
|
||||
continue
|
||||
}
|
||||
|
||||
r.NewRoute().Methods(http.MethodGet).Path("/" + name).HandlerFunc(p.serve(name))
|
||||
}
|
||||
|
||||
r.NewRoute().Methods(http.MethodGet).Path("/").HandlerFunc(p.index)
|
||||
|
||||
r.NewRoute().Methods(http.MethodGet).Path("/api/clusters").HandlerFunc(p.clusters)
|
||||
r.NewRoute().Methods(http.MethodPost).Path("/api/logout").Handler(p.aad.Logout("/"))
|
||||
}
|
||||
|
||||
func (p *portal) serve(path string) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
b, err := Asset(path)
|
||||
if err != nil {
|
||||
p.internalServerError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeContent(w, r, path, time.Time{}, bytes.NewReader(b))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *portal) index(w http.ResponseWriter, r *http.Request) {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
err := p.t.ExecuteTemplate(buf, "index.html", map[string]interface{}{
|
||||
"location": p.env.Location(),
|
||||
csrf.TemplateTag: csrf.TemplateField(r),
|
||||
})
|
||||
if err != nil {
|
||||
p.internalServerError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeContent(w, r, "index.html", time.Time{}, bytes.NewReader(buf.Bytes()))
|
||||
}
|
||||
|
||||
// clusters writes, as a JSON array, the sorted resource IDs of all
// clusters stable enough to inspect via the portal.  Clusters which are
// mid-create or mid-delete (including ones that failed while doing so)
// are filtered out.
func (p *portal) clusters(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	docs, err := p.dbOpenShiftClusters.ListAll(ctx)
	if err != nil {
		p.internalServerError(w, err)
		return
	}

	clusters := make([]string, 0, len(docs.OpenShiftClusterDocuments))
	for _, doc := range docs.OpenShiftClusterDocuments {
		ps := doc.OpenShiftCluster.Properties.ProvisioningState
		fps := doc.OpenShiftCluster.Properties.FailedProvisioningState

		switch {
		case ps == api.ProvisioningStateCreating,
			ps == api.ProvisioningStateDeleting,
			ps == api.ProvisioningStateFailed &&
				(fps == api.ProvisioningStateCreating ||
					fps == api.ProvisioningStateDeleting):
			// skip: cluster is being created/deleted, or failed doing so
		default:
			clusters = append(clusters, doc.OpenShiftCluster.ID)
		}
	}

	sort.Strings(clusters)

	b, err := json.MarshalIndent(clusters, "", "    ")
	if err != nil {
		p.internalServerError(w, err)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	_, _ = w.Write(b)
}
|
||||
|
||||
// internalServerError logs err at warning level and replies 500 with the
// standard status text only — no error detail is leaked to the client.
func (p *portal) internalServerError(w http.ResponseWriter, err error) {
	p.log.Warn(err)
	http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
|
|
@ -0,0 +1,67 @@
|
|||
package portal
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
testdatabase "github.com/Azure/ARO-RP/test/database"
|
||||
)
|
||||
|
||||
// TestClusters verifies that the /api/clusters handler returns only
// stable clusters (the Creating cluster is filtered out) as JSON.
func TestClusters(t *testing.T) {
	dbOpenShiftClusters, _ := testdatabase.NewFakeOpenShiftClusters()

	fixture := testdatabase.NewFixture().
		WithOpenShiftClusters(dbOpenShiftClusters)

	// one Succeeded (should be listed) and one Creating (should be hidden)
	fixture.AddOpenShiftClusterDocuments(&api.OpenShiftClusterDocument{
		Key: "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourcegroupname/providers/microsoft.redhatopenshift/openshiftclusters/succeeded",
		OpenShiftCluster: &api.OpenShiftCluster{
			ID: "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourceGroupName/providers/microsoft.redhatopenshift/openshiftclusters/succeeded",
			Properties: api.OpenShiftClusterProperties{
				ProvisioningState: api.ProvisioningStateSucceeded,
			},
		},
	}, &api.OpenShiftClusterDocument{
		Key: "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourcegroupname/providers/microsoft.redhatopenshift/openshiftclusters/creating",
		OpenShiftCluster: &api.OpenShiftCluster{
			ID: "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourceGroupName/providers/microsoft.redhatopenshift/openshiftclusters/creating",
			Properties: api.OpenShiftClusterProperties{
				ProvisioningState: api.ProvisioningStateCreating,
			},
		},
	})

	err := fixture.Create()
	if err != nil {
		t.Fatal(err)
	}

	p := &portal{
		dbOpenShiftClusters: dbOpenShiftClusters,
	}

	w := httptest.NewRecorder()

	p.clusters(w, &http.Request{})

	if w.Header().Get("Content-Type") != "application/json" {
		t.Error(w.Header().Get("Content-Type"))
	}

	var r []string
	err = json.NewDecoder(w.Body).Decode(&r)
	if err != nil {
		t.Fatal(err)
	}

	if !reflect.DeepEqual(r, []string{"/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourceGroupName/providers/microsoft.redhatopenshift/openshiftclusters/succeeded"}) {
		t.Error(r)
	}
}
|
|
@ -0,0 +1,57 @@
|
|||
package prometheus
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/database"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/util/clientcache"
|
||||
"github.com/Azure/ARO-RP/pkg/proxy"
|
||||
"github.com/Azure/ARO-RP/pkg/util/roundtripper"
|
||||
)
|
||||
|
||||
// prometheus reverse-proxies portal users through to the prometheus-k8s-0
// pod of a given cluster, via Kubernetes port forwarding.
type prometheus struct {
	log *logrus.Entry

	dbOpenShiftClusters database.OpenShiftClusters

	dialer      proxy.Dialer
	clientCache clientcache.ClientCache // cluster resource ID -> *http.Client
}
|
||||
|
||||
// New builds the Prometheus reverse proxy and registers its routes on the
// AAD-authenticated router.  Per-cluster port-forwarding HTTP clients are
// cached for an hour.
func New(baseLog *logrus.Entry,
	dbOpenShiftClusters database.OpenShiftClusters,
	dialer proxy.Dialer,
	aadAuthenticatedRouter *mux.Router) *prometheus {
	p := &prometheus{
		log: baseLog,

		dbOpenShiftClusters: dbOpenShiftClusters,

		dialer:      dialer,
		clientCache: clientcache.New(time.Hour),
	}

	rp := &httputil.ReverseProxy{
		Director:       p.director,
		Transport:      roundtripper.RoundTripperFunc(p.roundTripper),
		ModifyResponse: p.modifyResponse,
		ErrorLog:       log.New(p.log.Writer(), "", 0),
	}

	// redirect .../prometheus to .../prometheus/ so relative links emitted
	// by modifyResponse resolve against the prefix
	aadAuthenticatedRouter.NewRoute().Path("/subscriptions/{subscriptionId}/resourcegroups/{resourceGroupName}/providers/microsoft.redhatopenshift/openshiftclusters/{resourceName}/prometheus").HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		r.URL.Path += "/"
		http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
	})
	aadAuthenticatedRouter.NewRoute().PathPrefix("/subscriptions/{subscriptionId}/resourcegroups/{resourceGroupName}/providers/microsoft.redhatopenshift/openshiftclusters/{resourceName}/prometheus/").Handler(rp)

	return p
}
|
|
@ -0,0 +1,171 @@
|
|||
package prometheus
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/html"
|
||||
"golang.org/x/net/html/atom"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api/validate"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/util/responsewriter"
|
||||
"github.com/Azure/ARO-RP/pkg/util/portforward"
|
||||
"github.com/Azure/ARO-RP/pkg/util/restconfig"
|
||||
)
|
||||
|
||||
// Unfortunately the signature of httputil.ReverseProxy.Director does not allow
// us to return errors. We get around this limitation slightly naughtily by
// storing return information in the request context.

// contextKey identifies values stashed in the request context by director
// for roundTripper to pick up.
type contextKey int

const (
	contextKeyClient   contextKey = iota // *http.Client selected by director
	contextKeyResponse                   // pre-built error *http.Response (short-circuit)
)
|
||||
|
||||
// director implements httputil.ReverseProxy.Director: it validates the
// cluster resource ID embedded in the request path, fetches (or reuses a
// cached) port-forwarding HTTP client for that cluster, and rewrites the
// request to target the cluster's prometheus-k8s-0 pod.  Errors are
// signalled to roundTripper through the request context (see contextKey).
func (p *prometheus) director(r *http.Request) {
	ctx := r.Context()

	// the first 9 path segments form the cluster resource ID; the mux
	// route guarantees the shape, RxClusterID re-validates the content
	resourceID := strings.Join(strings.Split(r.URL.Path, "/")[:9], "/")
	if !validate.RxClusterID.MatchString(resourceID) {
		p.error(r, http.StatusBadRequest, nil)
		return
	}

	cli := p.clientCache.Get(resourceID)
	if cli == nil {
		var err error
		cli, err = p.cli(ctx, resourceID)
		if err != nil {
			p.error(r, http.StatusInternalServerError, err)
			return
		}

		p.clientCache.Put(resourceID, cli)
	}

	r.RequestURI = ""
	r.URL.Scheme = "http"
	r.URL.Host = "prometheus-k8s-0:9090"
	// strip the resource ID + "/prometheus" prefix before forwarding
	r.URL.Path = "/" + strings.Join(strings.Split(r.URL.Path, "/")[10:], "/")
	r.Header.Del("Cookie") // don't leak portal session cookies to the cluster
	r.Header.Del("Referer")
	r.Host = r.URL.Host

	// http.Request.WithContext returns a copy of the original Request with the
	// new context, but we have no way to return it, so we overwrite our
	// existing request.
	*r = *r.WithContext(context.WithValue(ctx, contextKeyClient, cli))
}
|
||||
|
||||
// cli builds an HTTP client whose connections are port-forwarded to the
// prometheus-k8s-0 pod of the given cluster, using the cluster's ARO
// service kubeconfig stored in the database.
func (p *prometheus) cli(ctx context.Context, resourceID string) (*http.Client, error) {
	openShiftDoc, err := p.dbOpenShiftClusters.Get(ctx, resourceID)
	if err != nil {
		return nil, err
	}

	restconfig, err := restconfig.RestConfig(p.dialer, openShiftDoc.OpenShiftCluster)
	if err != nil {
		return nil, err
	}

	return &http.Client{
		Transport: &http.Transport{
			// every dial, regardless of requested address, lands on the
			// cluster's prometheus-k8s-0 pod via port forwarding
			DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
				return portforward.DialContext(ctx, p.log, restconfig, "openshift-monitoring", "prometheus-k8s-0", "9090")
			},
		},
	}, nil
}
|
||||
|
||||
func (p *prometheus) roundTripper(r *http.Request) (*http.Response, error) {
|
||||
if resp, ok := r.Context().Value(contextKeyResponse).(*http.Response); ok {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
cli := r.Context().Value(contextKeyClient).(*http.Client)
|
||||
return cli.Do(r)
|
||||
}
|
||||
|
||||
// modifyResponse rewrites text/html responses from Prometheus so that
// absolute links work behind the portal's path prefix (see makeRelative).
// Non-HTML responses pass through untouched; unparseable HTML is passed
// through verbatim.
func (p *prometheus) modifyResponse(r *http.Response) error {
	mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
	if mediaType != "text/html" {
		return nil
	}

	b, err := ioutil.ReadAll(r.Body)
	if err != nil {
		return err
	}

	buf := &bytes.Buffer{}

	n, err := html.Parse(bytes.NewReader(b))
	if err != nil {
		// not parseable: pass the body through unmodified
		buf.Write(b)

	} else {
		walk(n, makeRelative)

		err = html.Render(buf, n)
		if err != nil {
			return err
		}

		// rendering changed the body length; keep Content-Length consistent
		r.Header.Set("Content-Length", strconv.FormatInt(int64(buf.Len()), 10))
	}

	r.Body = ioutil.NopCloser(buf)

	return nil
}
|
||||
|
||||
func makeRelative(n *html.Node) {
|
||||
switch n.DataAtom {
|
||||
case atom.A, atom.Link:
|
||||
for i, attr := range n.Attr {
|
||||
if attr.Namespace == "" && attr.Key == "href" && strings.HasPrefix(n.Attr[i].Val, "/") {
|
||||
n.Attr[i].Val = "." + n.Attr[i].Val
|
||||
}
|
||||
}
|
||||
case atom.Script:
|
||||
for i, attr := range n.Attr {
|
||||
if attr.Namespace == "" && attr.Key == "src" && strings.HasPrefix(n.Attr[i].Val, "/") {
|
||||
n.Attr[i].Val = "." + n.Attr[i].Val
|
||||
}
|
||||
}
|
||||
|
||||
if len(n.Attr) == 0 {
|
||||
n.FirstChild.Data = strings.Replace(n.FirstChild.Data, `var PATH_PREFIX = "";`, `var PATH_PREFIX = ".";`, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func walk(n *html.Node, f func(*html.Node)) {
|
||||
f(n)
|
||||
|
||||
for c := n.FirstChild; c != nil; c = c.NextSibling {
|
||||
walk(c, f)
|
||||
}
|
||||
}
|
||||
|
||||
// error logs err (if non-nil) and stashes a canned plain-text error
// response in the request context for roundTripper to return verbatim —
// Director itself cannot return errors.
func (p *prometheus) error(r *http.Request, statusCode int, err error) {
	if err != nil {
		p.log.Print(err)
	}

	w := responsewriter.New(r)
	http.Error(w, http.StatusText(statusCode), statusCode)

	*r = *r.WithContext(context.WithValue(r.Context(), contextKeyResponse, w.Response()))
}
|
|
@ -0,0 +1,375 @@
|
|||
package prometheus
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sirupsen/logrus"
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
|
||||
v1 "k8s.io/client-go/tools/clientcmd/api/v1"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
"github.com/Azure/ARO-RP/pkg/database/cosmosdb"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/util/responsewriter"
|
||||
mock_proxy "github.com/Azure/ARO-RP/pkg/util/mocks/proxy"
|
||||
"github.com/Azure/ARO-RP/pkg/util/portforward"
|
||||
utiltls "github.com/Azure/ARO-RP/pkg/util/tls"
|
||||
testdatabase "github.com/Azure/ARO-RP/test/database"
|
||||
"github.com/Azure/ARO-RP/test/util/listener"
|
||||
)
|
||||
|
||||
// conn wraps a hijacked net.Conn so that reads go through the hijacker's
// buffered ReadWriter, which may already hold bytes read off the wire.
type conn struct {
	net.Conn
	rw *bufio.ReadWriter
}
|
||||
|
||||
// Read reads via the buffered ReadWriter rather than the raw connection,
// so data buffered during Hijack is not lost.
func (c *conn) Read(b []byte) (int, error) {
	return c.rw.Read(b)
}
|
||||
|
||||
// fakeServer returns a test listener for an HTTPS server which validates its
// client and forwards a SPDY request to a second HTTP server which echos back
// the request it received.
func fakeServer(cacerts []*x509.Certificate, serverkey *rsa.PrivateKey, servercerts []*x509.Certificate) *listener.Listener {
	podl := listener.NewListener()

	// "pod" side: echoes the raw request back to the caller
	go func() {
		_ = http.Serve(podl, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			b, _ := httputil.DumpRequest(r, true)
			_, _ = w.Write(b)
		}))
	}()

	kubel := listener.NewListener()

	pool := x509.NewCertPool()
	pool.AddCert(cacerts[0])

	// "kube-apiserver" side: TLS with mandatory client cert; upgrades the
	// connection to SPDY and bridges its data/error streams through to podl
	go func() {
		_ = http.Serve(tls.NewListener(kubel, &tls.Config{
			Certificates: []tls.Certificate{
				{
					Certificate: [][]byte{
						servercerts[0].Raw,
					},
					PrivateKey: serverkey,
				},
			},
			ClientAuth: tls.RequireAndVerifyClientCert,
			ClientCAs:  pool,
		}), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
			w.Header().Add(httpstream.HeaderUpgrade, spdy.HeaderSpdy31)

			w.WriteHeader(http.StatusSwitchingProtocols)

			// take over the raw connection for the SPDY handshake
			hijacker := w.(http.Hijacker)
			c, buf, err := hijacker.Hijack()
			if err != nil {
				panic(err)
			}

			var mu sync.Mutex
			var dataStream, errorStream httpstream.Stream
			var dataReplySent, errorReplySent <-chan struct{}

			var serverconn httpstream.Connection
			serverconn, err = spdy.NewServerConnection(&conn{Conn: c, rw: buf}, func(stream httpstream.Stream, replySent <-chan struct{}) error {
				// mu serializes stream registration; the callback may run
				// concurrently per stream
				mu.Lock()
				defer mu.Unlock()

				switch stream.Headers().Get(corev1.StreamType) {
				case corev1.StreamTypeData:
					if dataStream != nil {
						return fmt.Errorf("dataStream already set")
					}
					dataStream = stream
					dataReplySent = replySent
				case corev1.StreamTypeError:
					if errorStream != nil {
						return fmt.Errorf("errorStream already set")
					}
					errorStream = stream
					errorReplySent = replySent
				}

				// once both streams exist (and their replies are sent), hand a
				// joined stream connection to the pod listener
				if dataStream != nil && errorStream != nil {
					go func() {
						<-dataReplySent
						<-errorReplySent

						podl.Enqueue(portforward.NewStreamConn(nil, serverconn, dataStream, errorStream))
					}()
				}

				return nil
			})
			if err != nil {
				panic(err)
			}
		}))

		podl.Close()
	}()

	return kubel
}
|
||||
|
||||
// testKubeconfig returns a serialized kubeconfig for the given CA and
// client credentials, pointing at https://kubernetes:6443.
func testKubeconfig(cacerts []*x509.Certificate, clientkey *rsa.PrivateKey, clientcerts []*x509.Certificate) ([]byte, error) {
	kc := &v1.Config{
		Clusters: []v1.NamedCluster{
			{
				Cluster: v1.Cluster{
					Server: "https://kubernetes:6443",
				},
			},
		},
		AuthInfos: []v1.NamedAuthInfo{
			{},
		},
	}

	var err error
	kc.AuthInfos[0].AuthInfo.ClientKeyData, err = utiltls.PrivateKeyAsBytes(clientkey)
	if err != nil {
		return nil, err
	}

	kc.AuthInfos[0].AuthInfo.ClientCertificateData, err = utiltls.CertAsBytes(clientcerts[0])
	if err != nil {
		return nil, err
	}

	kc.Clusters[0].Cluster.CertificateAuthorityData, err = utiltls.CertAsBytes(cacerts[0])
	if err != nil {
		return nil, err
	}

	return json.Marshal(kc)
}
|
||||
|
||||
// TestProxy drives the Prometheus reverse proxy end-to-end against
// fakeServer: the success case round-trips through the mocked dialer and
// SPDY port-forward to the echoing "pod"; the other cases exercise the
// bad-path and database-error short-circuits in director.
func TestProxy(t *testing.T) {
	ctx := context.Background()
	resourceID := "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster"
	privateEndpointIP := "1.2.3.4"

	cakey, cacerts, err := utiltls.GenerateKeyAndCertificate("ca", nil, nil, true, false)
	if err != nil {
		t.Fatal(err)
	}

	serverkey, servercerts, err := utiltls.GenerateKeyAndCertificate("kubernetes", cakey, cacerts[0], false, false)
	if err != nil {
		t.Fatal(err)
	}

	serviceClientkey, serviceClientcerts, err := utiltls.GenerateKeyAndCertificate("system:aro-service", cakey, cacerts[0], false, true)
	if err != nil {
		t.Fatal(err)
	}

	serviceKubeconfig, err := testKubeconfig(cacerts, serviceClientkey, serviceClientcerts)
	if err != nil {
		t.Fatal(err)
	}

	l := fakeServer(cacerts, serverkey, servercerts)
	defer l.Close()

	for _, tt := range []struct {
		name           string
		r              func(*http.Request) // optional request mutation
		fixtureChecker func(*testdatabase.Fixture, *testdatabase.Checker, *cosmosdb.FakeOpenShiftClusterDocumentClient)
		mocks          func(*mock_proxy.MockDialer)
		wantStatusCode int
		wantBody       string
	}{
		{
			name: "success",
			fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient) {
				openShiftClusterDocument := &api.OpenShiftClusterDocument{
					ID:  resourceID,
					Key: resourceID,
					OpenShiftCluster: &api.OpenShiftCluster{
						Properties: api.OpenShiftClusterProperties{
							NetworkProfile: api.NetworkProfile{
								PrivateEndpointIP: privateEndpointIP,
							},
							AROServiceKubeconfig: api.SecureBytes(serviceKubeconfig),
						},
					},
				}

				fixture.AddOpenShiftClusterDocuments(openShiftClusterDocument)
				checker.AddOpenShiftClusterDocuments(openShiftClusterDocument)
			},
			mocks: func(dialer *mock_proxy.MockDialer) {
				dialer.EXPECT().DialContext(gomock.Any(), "tcp", privateEndpointIP+":6443").Return(l.DialContext(ctx, "", ""))
			},
			wantStatusCode: http.StatusOK,
			// the fake pod echoes the request it received
			wantBody: "GET /test HTTP/1.1\r\nHost: prometheus-k8s-0:9090\r\nAccept-Encoding: gzip\r\nUser-Agent: Go-http-client/1.1\r\n\r\n",
		},
		{
			name: "bad path",
			r: func(r *http.Request) {
				r.URL.Path = "/subscriptions/BAD/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster/prometheus/test"
			},
			wantStatusCode: http.StatusBadRequest,
			wantBody:       "Bad Request\n",
		},
		{
			name: "sad database",
			fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient) {
				openShiftClustersClient.SetError(fmt.Errorf("sad"))
			},
			wantStatusCode: http.StatusInternalServerError,
			wantBody:       "Internal Server Error\n",
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			dbOpenShiftClusters, openShiftClustersClient := testdatabase.NewFakeOpenShiftClusters()

			fixture := testdatabase.NewFixture().
				WithOpenShiftClusters(dbOpenShiftClusters)

			checker := testdatabase.NewChecker()

			if tt.fixtureChecker != nil {
				tt.fixtureChecker(fixture, checker, openShiftClustersClient)
			}

			err := fixture.Create()
			if err != nil {
				t.Fatal(err)
			}

			r, err := http.NewRequest(http.MethodGet,
				"https://localhost:8444"+resourceID+"/prometheus/test", nil)
			if err != nil {
				panic(err)
			}

			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			dialer := mock_proxy.NewMockDialer(ctrl)
			if tt.mocks != nil {
				tt.mocks(dialer)
			}

			aadAuthenticatedRouter := &mux.Router{}

			New(logrus.NewEntry(logrus.StandardLogger()), dbOpenShiftClusters, dialer, aadAuthenticatedRouter)

			if tt.r != nil {
				tt.r(r)
			}

			w := responsewriter.New(r)

			aadAuthenticatedRouter.ServeHTTP(w, r)

			resp := w.Response()

			if resp.StatusCode != tt.wantStatusCode {
				t.Error(resp.StatusCode)
			}

			if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" {
				t.Error(resp.Header.Get("Content-Type"))
			}

			b, err := ioutil.ReadAll(resp.Body)
			if err != nil {
				t.Fatal(err)
			}

			if string(b) != tt.wantBody {
				t.Errorf("%q", string(b))
			}
		})
	}
}
|
||||
|
||||
// TestModifyResponse checks the HTML link rewriting done by
// modifyResponse/makeRelative and that Content-Length is kept in sync
// with the rewritten body.
func TestModifyResponse(t *testing.T) {
	for _, tt := range []struct {
		name     string
		body     string // upstream HTML
		wantbody string // HTML after rewriting
	}{
		{
			name:     "makes absolute a hrefs relative",
			body:     `<html><head></head><body><a href="/foo"></a></body></html>`,
			wantbody: `<html><head></head><body><a href="./foo"></a></body></html>`,
		},
		{
			name:     "makes absolute link hrefs relative",
			body:     `<html><head></head><body><link href="/foo"/></body></html>`,
			wantbody: `<html><head></head><body><link href="./foo"/></body></html>`,
		},
		{
			name:     "makes absolute script srcs relative",
			body:     `<html><head></head><body><script src="/foo"></script></body></html>`,
			wantbody: `<html><head></head><body><script src="./foo"></script></body></html>`,
		},
		{
			name:     "makes PATH_PREFIX variable relative",
			body:     `<html><head></head><body><script>var PATH_PREFIX = "";</script></body></html>`,
			wantbody: `<html><head></head><body><script>var PATH_PREFIX = ".";</script></body></html>`,
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			r := &http.Response{
				Header: http.Header{
					"Content-Type": []string{"text/html"},
				},
				Body: ioutil.NopCloser(strings.NewReader(tt.body)),
			}

			p := &prometheus{}

			err := p.modifyResponse(r)
			if err != nil {
				t.Fatal(err)
			}

			body, err := ioutil.ReadAll(r.Body)
			if err != nil {
				t.Fatal(err)
			}

			if string(body) != tt.wantbody {
				t.Errorf("%q", string(body))
			}

			// Content-Length must match the rewritten body
			length, err := strconv.Atoi(r.Header.Get("Content-Length"))
			if err != nil {
				t.Fatal(err)
			}

			if length != len(body) {
				t.Error("length mismatch")
			}
		})
	}
}
|
|
@ -0,0 +1,285 @@
|
|||
package portal
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/portal/middleware"
|
||||
"github.com/Azure/ARO-RP/pkg/util/deployment"
|
||||
mock_env "github.com/Azure/ARO-RP/pkg/util/mocks/env"
|
||||
utiltls "github.com/Azure/ARO-RP/pkg/util/tls"
|
||||
testdatabase "github.com/Azure/ARO-RP/test/database"
|
||||
"github.com/Azure/ARO-RP/test/util/listener"
|
||||
)
|
||||
|
||||
var (
	// elevatedGroupIDs is the AAD group granted elevated (SSH/kubeconfig)
	// access in these tests.
	elevatedGroupIDs = []string{"00000000-0000-0000-0000-000000000000"}
)
|
||||
|
||||
// TestSecurity spins up a real portal over an in-memory listener and probes
// each route three ways (unauthenticated, authenticated, elevated),
// asserting the expected status code for every combination. Defaults when a
// case leaves wantStatusCode zero: unauthenticated requests are expected to
// redirect to login (307), authenticated requests to succeed (200).
func TestSecurity(t *testing.T) {
	ctx := context.Background()
	log := logrus.NewEntry(logrus.StandardLogger())

	controller := gomock.NewController(t)
	defer controller.Finish()

	_env := mock_env.NewMockCore(controller)
	_env.EXPECT().DeploymentMode().AnyTimes().Return(deployment.Production)
	_env.EXPECT().Location().AnyTimes().Return("eastus")
	_env.EXPECT().TenantID().AnyTimes().Return("00000000-0000-0000-0000-000000000001")

	// In-memory listeners for the HTTPS and SSH endpoints of the portal.
	l := listener.NewListener()
	defer l.Close()

	sshl := listener.NewListener()
	defer sshl.Close()

	serverkey, servercerts, err := utiltls.GenerateKeyAndCertificate("server", nil, nil, false, false)
	if err != nil {
		t.Fatal(err)
	}

	sshkey, _, err := utiltls.GenerateKeyAndCertificate("ssh", nil, nil, false, false)
	if err != nil {
		t.Fatal(err)
	}

	dbOpenShiftClusters, _ := testdatabase.NewFakeOpenShiftClusters()
	dbPortal, _ := testdatabase.NewFakePortal()

	// Trust the freshly generated server certificate so the client can
	// complete TLS against the test portal.
	pool := x509.NewCertPool()
	pool.AddCert(servercerts[0])

	c := &http.Client{
		Transport: &http.Transport{
			DialContext: l.DialContext,
			TLSClientConfig: &tls.Config{
				RootCAs: pool,
			},
		},
		// Don't follow redirects: the tests assert on the redirect status
		// codes themselves.
		CheckRedirect: func(*http.Request, []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}

	// Session/CSRF keys are zeroed byte slices, matching what addCSRF and
	// addAuth encode with.
	p := NewPortal(_env, log, log, l, sshl, nil, "", serverkey, servercerts, "", nil, nil, make([]byte, 32), sshkey, nil, elevatedGroupIDs, dbOpenShiftClusters, dbPortal, nil)
	go func() {
		err := p.Run(ctx)
		if err != nil {
			log.Error(err)
		}
	}()

	for _, tt := range []struct {
		name          string
		request       func() (*http.Request, error)
		checkResponse func(*testing.T, bool, bool, *http.Response)
		// zero values fall back to the defaults described in the func doc
		unauthenticatedWantStatusCode int
		authenticatedWantStatusCode   int
	}{
		{
			name: "/",
			request: func() (*http.Request, error) {
				return http.NewRequest(http.MethodGet, "https://server/", nil)
			},
		},
		{
			name: "/index.js",
			request: func() (*http.Request, error) {
				return http.NewRequest(http.MethodGet, "https://server/index.js", nil)
			},
		},
		{
			name: "/api/clusters",
			request: func() (*http.Request, error) {
				return http.NewRequest(http.MethodGet, "https://server/api/clusters", nil)
			},
		},
		{
			name: "/api/logout",
			request: func() (*http.Request, error) {
				return http.NewRequest(http.MethodPost, "https://server/api/logout", nil)
			},
			authenticatedWantStatusCode: http.StatusSeeOther,
		},
		{
			name: "/callback",
			request: func() (*http.Request, error) {
				return http.NewRequest(http.MethodGet, "https://server/callback", nil)
			},
			authenticatedWantStatusCode: http.StatusTemporaryRedirect,
		},
		{
			// health endpoint must be reachable without authentication
			name: "/healthz/ready",
			request: func() (*http.Request, error) {
				return http.NewRequest(http.MethodGet, "https://server/healthz/ready", nil)
			},
			unauthenticatedWantStatusCode: http.StatusOK,
		},
		{
			name: "/kubeconfig/new",
			request: func() (*http.Request, error) {
				return http.NewRequest(http.MethodPost, "https://server/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourceGroupName/providers/microsoft.redhatopenshift/openshiftclusters/resourceName/kubeconfig/new", nil)
			},
		},
		{
			name: "/prometheus",
			request: func() (*http.Request, error) {
				return http.NewRequest(http.MethodPost, "https://server/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourceGroupName/providers/microsoft.redhatopenshift/openshiftclusters/resourceName/prometheus", nil)
			},
			authenticatedWantStatusCode: http.StatusTemporaryRedirect,
		},
		{
			// requires elevation: authenticated-but-not-elevated callers get
			// an explicit error body, checked below
			name: "/ssh/new",
			request: func() (*http.Request, error) {
				req, err := http.NewRequest(http.MethodPost, "https://server/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourceGroupName/providers/microsoft.redhatopenshift/openshiftclusters/resourceName/ssh/new", strings.NewReader("{}"))
				if err != nil {
					return nil, err
				}
				req.Header.Set("Content-Type", "application/json")

				return req, nil
			},
			checkResponse: func(t *testing.T, authenticated, elevated bool, resp *http.Response) {
				if authenticated && !elevated {
					var e struct {
						Error string
					}
					err := json.NewDecoder(resp.Body).Decode(&e)
					if err != nil {
						t.Fatal(err)
					}
					if e.Error != "Elevated access is required." {
						t.Error(e.Error)
					}
				}
			},
		},
		{
			name: "/doesnotexist",
			request: func() (*http.Request, error) {
				return http.NewRequest(http.MethodGet, "https://server/doesnotexist", nil)
			},
			unauthenticatedWantStatusCode: http.StatusNotFound,
			authenticatedWantStatusCode:   http.StatusNotFound,
		},
	} {
		for _, tt2 := range []struct {
			name           string
			authenticated  bool
			elevated       bool
			wantStatusCode int
		}{
			{
				name:           "unauthenticated",
				wantStatusCode: tt.unauthenticatedWantStatusCode,
			},
			{
				name:           "authenticated",
				authenticated:  true,
				wantStatusCode: tt.authenticatedWantStatusCode,
			},
			{
				// elevated is expected to behave like authenticated for
				// every route except where checkResponse says otherwise
				name:           "elevated",
				authenticated:  true,
				elevated:       true,
				wantStatusCode: tt.authenticatedWantStatusCode,
			},
		} {
			t.Run(tt2.name+tt.name, func(t *testing.T) {
				req, err := tt.request()
				if err != nil {
					t.Fatal(err)
				}

				err = addCSRF(req)
				if err != nil {
					t.Fatal(err)
				}

				if tt2.authenticated {
					var groups []string
					if tt2.elevated {
						groups = elevatedGroupIDs
					}
					err = addAuth(req, groups)
					if err != nil {
						t.Fatal(err)
					}
				}

				resp, err := c.Do(req)
				if err != nil {
					t.Fatal(err)
				}
				defer resp.Body.Close()

				// apply the documented defaults for unset expectations
				if tt2.wantStatusCode == 0 {
					if tt2.authenticated {
						tt2.wantStatusCode = http.StatusOK
					} else {
						tt2.wantStatusCode = http.StatusTemporaryRedirect
					}
				}

				if resp.StatusCode != tt2.wantStatusCode {
					t.Error(resp.StatusCode)
				}

				if tt.checkResponse != nil {
					tt.checkResponse(t, tt2.authenticated, tt2.elevated, resp)
				}
			})
		}
	}
}
|
||||
|
||||
func addCSRF(req *http.Request) error {
|
||||
if req.Method != http.MethodPost {
|
||||
return nil
|
||||
}
|
||||
|
||||
req.Header.Set("X-CSRF-Token", base64.StdEncoding.EncodeToString(make([]byte, 64)))
|
||||
|
||||
sc := securecookie.New(make([]byte, 32), nil)
|
||||
sc.SetSerializer(securecookie.JSONEncoder{})
|
||||
|
||||
cookie, err := sc.Encode("_gorilla_csrf", make([]byte, 32))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("Cookie", "_gorilla_csrf="+cookie)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func addAuth(req *http.Request, groups []string) error {
|
||||
store := sessions.NewCookieStore(make([]byte, 32))
|
||||
|
||||
cookie, err := securecookie.EncodeMulti(middleware.SessionName, map[interface{}]interface{}{
|
||||
middleware.SessionKeyUsername: "username",
|
||||
middleware.SessionKeyGroups: groups,
|
||||
middleware.SessionKeyExpires: time.Now().Add(time.Hour),
|
||||
}, store.Codecs...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("Cookie", middleware.SessionName+"="+cookie)
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,334 @@
|
|||
package ssh
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
utillog "github.com/Azure/ARO-RP/pkg/util/log"
|
||||
"github.com/Azure/ARO-RP/pkg/util/recover"
|
||||
)
|
||||
|
||||
const (
	// sshTimeout bounds the total lifetime of a proxied SSH connection; the
	// proxy loop in proxyConn exits when it elapses.
	sshTimeout = time.Hour
)
|
||||
|
||||
func (s *ssh) Run() error {
|
||||
go func() {
|
||||
defer recover.Panic(s.log)
|
||||
|
||||
for {
|
||||
c, err := s.l.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer recover.Panic(s.log)
|
||||
|
||||
err := s.newConn(context.Background(), c)
|
||||
if err != nil {
|
||||
s.log.Warn(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// newConn handles one incoming operator SSH connection: it authenticates it
// against a single-use portal record, dials the matching cluster master via
// the private endpoint, authenticates onward with the cluster's SSH key, and
// proxies the two connections together. Both sockets are closed before
// returning.
func (s *ssh) newConn(ctx context.Context, c1 net.Conn) error {
	defer c1.Close()

	// Shallow-copy the base config so the per-connection PasswordCallback
	// set below doesn't leak into other connections.
	config := &cryptossh.ServerConfig{}
	*config = *s.baseServerConfig

	var portalDoc *api.PortalDocument
	var connmetadata cryptossh.ConnMetadata

	config.PasswordCallback = func(_connmetadata cryptossh.ConnMetadata, pw []byte) (*cryptossh.Permissions, error) {
		connmetadata = _connmetadata

		// The presented password is the portal document's UUID key.
		password, err := uuid.FromString(string(pw))
		if err != nil {
			return nil, fmt.Errorf("invalid username") // don't echo password attempt to logs
		}

		portalDoc, err = s.dbPortal.Get(ctx, password.String())
		if err != nil {
			return nil, fmt.Errorf("invalid username") // don't echo password attempt to logs
		}

		// Record must be an SSH record, and the SSH user must match the
		// local part (before "@") of the recorded portal username.
		if portalDoc.Portal.SSH == nil ||
			connmetadata.User() != strings.SplitN(portalDoc.Portal.Username, "@", 2)[0] {
			return nil, fmt.Errorf("invalid username")
		}

		// Delete the record on success so the password is single-use.
		return nil, s.dbPortal.Delete(ctx, portalDoc)
	}

	conn1, newchannels1, requests1, err := cryptossh.NewServerConn(c1, config)
	if err != nil {
		var username string
		if connmetadata != nil { // after a password attempt
			username = connmetadata.User()
		}
		s.baseAccessLog.WithFields(logrus.Fields{
			"remote_addr": c1.RemoteAddr().String(),
			"username":    username,
		}).Warn("authentication failed")

		return err
	}

	accessLog := utillog.EnrichWithPath(s.baseAccessLog, portalDoc.Portal.ID)
	accessLog = accessLog.WithFields(logrus.Fields{
		"hostname":    fmt.Sprintf("master-%d", portalDoc.Portal.SSH.Master),
		"remote_addr": c1.RemoteAddr().String(),
		"username":    portalDoc.Portal.Username,
	})

	accessLog.Print("authentication succeeded")

	openShiftDoc, err := s.dbOpenShiftClusters.Get(ctx, strings.ToLower(portalDoc.Portal.ID))
	if err != nil {
		return err
	}

	// Master N is reached on port 2200+N at the cluster's private endpoint.
	address := fmt.Sprintf("%s:%d", openShiftDoc.OpenShiftCluster.Properties.NetworkProfile.PrivateEndpointIP, 2200+portalDoc.Portal.SSH.Master)

	c2, err := s.dialer.DialContext(ctx, "tcp", address)
	if err != nil {
		return err
	}

	defer c2.Close()

	key, err := x509.ParsePKCS1PrivateKey(openShiftDoc.OpenShiftCluster.Properties.SSHKey)
	if err != nil {
		return err
	}

	signer, err := cryptossh.NewSignerFromSigner(key)
	if err != nil {
		return err
	}

	conn2, newchannels2, requests2, err := cryptossh.NewClientConn(c2, "", &cryptossh.ClientConfig{
		User: "core",
		Auth: []cryptossh.AuthMethod{
			cryptossh.PublicKeys(signer),
		},
		// NOTE(review): host key verification is skipped on the hop to the
		// cluster — presumably acceptable over the private endpoint; confirm.
		HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
	})
	if err != nil {
		return err
	}

	t := time.Now()
	accessLog.Print("connected")
	defer func() {
		accessLog.WithFields(logrus.Fields{
			"duration": time.Since(t).Seconds(),
		}).Print("disconnected")
	}()

	// Hold the cluster key in an in-memory agent, served to the far side on
	// request (see handleAgent/proxyConn).
	keyring := agent.NewKeyring()
	err = keyring.Add(agent.AddedKey{PrivateKey: key})
	if err != nil {
		return err
	}

	return s.proxyConn(accessLog, keyring, conn1, conn2, newchannels1, newchannels2, requests1, requests2)
}
|
||||
|
||||
// proxyConn shuttles new channels and global requests between the
// operator-side connection (conn1/newchannels1/requests1) and the
// cluster-side connection (conn2/newchannels2/requests2) until either side's
// channel stream closes (nil receive) or sshTimeout elapses. All shutdown
// paths return nil.
func (s *ssh) proxyConn(accessLog *logrus.Entry, keyring agent.Agent, conn1, conn2 cryptossh.Conn, newchannels1, newchannels2 <-chan cryptossh.NewChannel, requests1, requests2 <-chan *cryptossh.Request) error {
	timer := time.NewTimer(sshTimeout)
	defer timer.Stop()

	var sessionOpened bool

	for {
		select {
		case <-timer.C:
			// hard cap on total connection lifetime
			return nil

		case nc := <-newchannels1:
			if nc == nil { // stream closed: client connection has gone away
				return nil
			}

			// on the first c->s session, advertise agent availability
			var firstSession bool
			if !sessionOpened && nc.ChannelType() == "session" {
				firstSession = true
				sessionOpened = true
			}

			go func() {
				_ = s.newChannel(accessLog, nc, conn1, conn2, firstSession)
			}()

		case nc := <-newchannels2:
			if nc == nil { // stream closed: server connection has gone away
				return nil
			}

			// hijack and handle incoming s->c agent requests
			if nc.ChannelType() == "auth-agent@openssh.com" {
				go func() {
					_ = s.handleAgent(accessLog, nc, keyring)
				}()
			} else {
				go func() {
					_ = s.newChannel(accessLog, nc, conn2, conn1, false)
				}()
			}

		case request := <-requests1:
			if request == nil {
				return nil
			}

			// forward c->s global request; failures end only that request
			_ = s.proxyGlobalRequest(request, conn2)

		case request := <-requests2:
			if request == nil {
				return nil
			}

			_ = s.proxyGlobalRequest(request, conn1)
		}
	}
}
|
||||
|
||||
func (s *ssh) handleAgent(accessLog *logrus.Entry, nc cryptossh.NewChannel, keyring agent.Agent) error {
|
||||
ch, rs, err := nc.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer ch.Close()
|
||||
|
||||
channelLog := accessLog.WithFields(logrus.Fields{
|
||||
"channel": nc.ChannelType(),
|
||||
})
|
||||
channelLog.Printf("opened")
|
||||
defer channelLog.Printf("closed")
|
||||
|
||||
go cryptossh.DiscardRequests(rs)
|
||||
|
||||
return agent.ServeAgent(keyring, ch)
|
||||
}
|
||||
|
||||
// newChannel proxies a single SSH channel: it opens the corresponding
// channel on conn2, accepts the incoming one from nc, optionally requests
// agent forwarding (first client session only), then hands both halves to
// proxyChannel. An OpenChannelError from the far side is relayed back as a
// rejection of the incoming channel.
func (s *ssh) newChannel(accessLog *logrus.Entry, nc cryptossh.NewChannel, conn1, conn2 cryptossh.Conn, firstSession bool) error {
	defer recover.Panic(s.log)

	// Open the peer channel first, so a rejection can be relayed before we
	// commit to accepting the incoming channel.
	ch2, rs2, err := conn2.OpenChannel(nc.ChannelType(), nc.ExtraData())
	if err, ok := err.(*cryptossh.OpenChannelError); ok { // deliberate shadow of err
		return nc.Reject(err.Reason, err.Message)
	} else if err != nil {
		return err
	}

	ch1, rs1, err := nc.Accept()
	if err != nil {
		return err
	}

	channelLog := accessLog.WithFields(logrus.Fields{
		"channel": nc.ChannelType(),
	})
	channelLog.Printf("opened")
	defer channelLog.Printf("closed")

	if firstSession {
		// Advertise SSH agent availability to the cluster-side sshd.
		_, err = ch2.SendRequest("auth-agent-req@openssh.com", true, nil)
		if err != nil {
			return err
		}
	}

	return s.proxyChannel(ch1, ch2, rs1, rs2)
}
|
||||
|
||||
func (s *ssh) proxyGlobalRequest(r *cryptossh.Request, c cryptossh.Conn) error {
|
||||
ok, payload, err := c.SendRequest(r.Type, r.WantReply, r.Payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return r.Reply(ok, payload)
|
||||
}
|
||||
|
||||
func (s *ssh) proxyRequest(r *cryptossh.Request, ch cryptossh.Channel) error {
|
||||
ok, err := ch.SendRequest(r.Type, r.WantReply, r.Payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return r.Reply(ok, nil)
|
||||
}
|
||||
|
||||
// proxyChannel pumps data and channel-level requests bidirectionally between
// ch1 and ch2, returning once all four pump goroutines finish. Each data
// pump half-closes its destination (CloseWrite) on EOF so the peer sees a
// clean end-of-stream; each request pump fully closes the opposite channel
// when its source request stream ends.
func (s *ssh) proxyChannel(ch1, ch2 cryptossh.Channel, rs1, rs2 <-chan *cryptossh.Request) error {
	var wg sync.WaitGroup
	wg.Add(4)

	// data: ch2 -> ch1
	go func() {
		defer recover.Panic(s.log)

		defer wg.Done()
		_, _ = io.Copy(ch1, ch2)
		_ = ch1.CloseWrite()
	}()

	// data: ch1 -> ch2
	go func() {
		defer recover.Panic(s.log)

		defer wg.Done()
		_, _ = io.Copy(ch2, ch1)
		_ = ch2.CloseWrite()
	}()

	// requests: rs1 -> ch2
	go func() {
		defer recover.Panic(s.log)

		defer wg.Done()
		for r := range rs1 {
			err := s.proxyRequest(r, ch2)
			if err != nil {
				break
			}
		}
		_ = ch2.Close()
	}()

	// requests: rs2 -> ch1
	go func() {
		defer recover.Panic(s.log)

		defer wg.Done()
		for r := range rs2 {
			err := s.proxyRequest(r, ch1)
			if err != nil {
				break
			}
		}
		_ = ch1.Close()
	}()

	wg.Wait()
	return nil
}
|
|
@ -0,0 +1,443 @@
|
|||
package ssh
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/onsi/gomega"
|
||||
"github.com/onsi/gomega/types"
|
||||
"github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
"github.com/Azure/ARO-RP/pkg/database/cosmosdb"
|
||||
mock_proxy "github.com/Azure/ARO-RP/pkg/util/mocks/proxy"
|
||||
"github.com/Azure/ARO-RP/pkg/util/tls"
|
||||
testdatabase "github.com/Azure/ARO-RP/test/database"
|
||||
"github.com/Azure/ARO-RP/test/util/bufferedpipe"
|
||||
"github.com/Azure/ARO-RP/test/util/listener"
|
||||
testlog "github.com/Azure/ARO-RP/test/util/log"
|
||||
)
|
||||
|
||||
// fakeClient runs a fake client on the given connection. It validates the
|
||||
// server key, authenticates, writes a ping request, reads a pong reply, then
|
||||
// closes the connection
|
||||
func fakeClient(c net.Conn, serverKey *rsa.PublicKey, user string, password string) error {
|
||||
publicKey, err := cryptossh.NewPublicKey(serverKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, _, _, err := cryptossh.NewClientConn(c, "", &cryptossh.ClientConfig{
|
||||
HostKeyCallback: cryptossh.FixedHostKey(publicKey),
|
||||
User: user,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.Password(password),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, reply, err := conn.SendRequest("ping", true, []byte("ping"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if string(reply) != "pong" {
|
||||
return fmt.Errorf("invalid reply %q", string(reply))
|
||||
}
|
||||
|
||||
return conn.Close()
|
||||
}
|
||||
|
||||
// fakeServer returns a test listener for an SSH server which validates the
// client key, reads ping request(s) and writes pong replies. The accept
// loop runs until the listener is closed.
func fakeServer(clientKey *rsa.PublicKey) (*listener.Listener, error) {
	l := listener.NewListener()

	clientPublicKey, err := cryptossh.NewPublicKey(clientKey)
	if err != nil {
		return nil, err
	}

	config := &cryptossh.ServerConfig{
		// Accept only user "core" presenting exactly clientKey — mirrors
		// what the proxy authenticates with in newConn.
		PublicKeyCallback: func(conn cryptossh.ConnMetadata, key cryptossh.PublicKey) (*cryptossh.Permissions, error) {
			if conn.User() != "core" {
				return nil, fmt.Errorf("invalid user")
			}
			if !bytes.Equal(key.Marshal(), clientPublicKey.Marshal()) {
				return nil, fmt.Errorf("invalid key")
			}
			return nil, nil
		},
	}

	// Throwaway host key for this fake server.
	key, _, err := tls.GenerateKeyAndCertificate("server", nil, nil, false, false)
	if err != nil {
		return nil, err
	}

	signer, err := cryptossh.NewSignerFromSigner(key)
	if err != nil {
		return nil, err
	}

	config.AddHostKey(signer)

	go func() {
		for {
			c, err := l.Accept()
			if err != nil {
				// listener closed: shut the accept loop down
				return
			}

			go func() {
				conn, _, requests, err := cryptossh.NewServerConn(c, config)
				if err != nil {
					return
				}

				go func() {
					// Reply "pong" to ping requests; refuse anything else.
					for request := range requests {
						if request.Type == "ping" && request.WantReply {
							err := request.Reply(true, []byte("pong"))
							if err != nil {
								break
							}
						} else {
							err := request.Reply(false, nil)
							if err != nil {
								break
							}
						}
					}
				}()

				_ = conn.Wait()
			}()
		}
	}()

	return l, nil
}
|
||||
|
||||
// TestProxy exercises (*ssh).newConn end to end over an in-memory pipe: a
// fakeClient authenticates to the proxy, which in turn dials a fakeServer
// standing in for a cluster master. Each case seeds the fake databases via
// fixtureChecker, optionally stubs the dialer, and asserts the client error
// prefix, the database end state, and the emitted access log entries.
func TestProxy(t *testing.T) {
	ctx := context.Background()
	username := "test"
	password := "00000000-0000-0000-0000-000000000000"
	subscriptionID := "10000000-0000-0000-0000-000000000000"
	resourceGroup := "rg"
	resourceName := "cluster"
	resourceID := "/subscriptions/" + subscriptionID + "/resourcegroups/" + resourceGroup + "/providers/microsoft.redhatopenshift/openshiftclusters/" + resourceName
	privateEndpointIP := "1.2.3.4"

	hostKey, _, err := tls.GenerateKeyAndCertificate("proxy", nil, nil, false, false)
	if err != nil {
		t.Fatal(err)
	}

	clusterKey, _, err := tls.GenerateKeyAndCertificate("cluster", nil, nil, false, false)
	if err != nil {
		t.Fatal(err)
	}

	// fake cluster sshd accepting only clusterKey (see fakeServer)
	l, err := fakeServer(&clusterKey.PublicKey)
	if err != nil {
		t.Fatal(err)
	}
	defer l.Close()

	// goodOpenShiftClusterDocument builds a cluster record whose SSH key and
	// private endpoint match the fake server above.
	goodOpenShiftClusterDocument := func() *api.OpenShiftClusterDocument {
		return &api.OpenShiftClusterDocument{
			ID:  resourceID,
			Key: resourceID,
			OpenShiftCluster: &api.OpenShiftCluster{
				Properties: api.OpenShiftClusterProperties{
					NetworkProfile: api.NetworkProfile{
						PrivateEndpointIP: privateEndpointIP,
					},
					SSHKey: api.SecureBytes(x509.MarshalPKCS1PrivateKey(clusterKey)),
				},
			},
		}
	}

	// goodPortalDocument builds a one-time SSH record keyed by id (the
	// password) targeting master-1 of the cluster above.
	goodPortalDocument := func(id string) *api.PortalDocument {
		return &api.PortalDocument{
			ID: id,
			Portal: &api.Portal{
				ID:       resourceID,
				Username: username,
				SSH: &api.SSH{
					Master: 1,
				},
			},
		}
	}

	type test struct {
		name           string
		username       string
		password       string
		fixtureChecker func(*test, *testdatabase.Fixture, *testdatabase.Checker, *cosmosdb.FakeOpenShiftClusterDocumentClient, *cosmosdb.FakePortalDocumentClient)
		mocks          func(*mock_proxy.MockDialer)
		wantErrPrefix  string
		wantLogs       []map[string]types.GomegaMatcher
	}

	for _, tt := range []*test{
		{
			name:     "good",
			username: username,
			password: password,
			fixtureChecker: func(tt *test, fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
				portalDocument := goodPortalDocument(tt.password)
				fixture.AddPortalDocuments(portalDocument)
				openShiftClusterDocument := goodOpenShiftClusterDocument()
				fixture.AddOpenShiftClusterDocuments(openShiftClusterDocument)
				// portal record is consumed on success: only the cluster
				// document should remain
				checker.AddOpenShiftClusterDocuments(openShiftClusterDocument)
			},
			mocks: func(dialer *mock_proxy.MockDialer) {
				// master-1 => port 2201 at the private endpoint
				dialer.EXPECT().DialContext(gomock.Any(), "tcp", privateEndpointIP+":2201").Return(l.DialContext(ctx, "", ""))
			},
			wantLogs: []map[string]types.GomegaMatcher{
				{
					"level":       gomega.Equal(logrus.InfoLevel),
					"msg":         gomega.Equal("authentication succeeded"),
					"remote_addr": gomega.Not(gomega.BeEmpty()),
					"username":    gomega.Equal(username),
				},
				{
					"level":           gomega.Equal(logrus.InfoLevel),
					"msg":             gomega.Equal("connected"),
					"hostname":        gomega.Equal("master-1"),
					"resource_group":  gomega.Equal(resourceGroup),
					"resource_id":     gomega.Equal(resourceID),
					"resource_name":   gomega.Equal(resourceName),
					"subscription_id": gomega.Equal(subscriptionID),
					"username":        gomega.Equal(username),
				},
				{
					"level":           gomega.Equal(logrus.InfoLevel),
					"msg":             gomega.Equal("disconnected"),
					"duration":        gomega.BeNumerically(">", 0),
					"hostname":        gomega.Equal("master-1"),
					"resource_group":  gomega.Equal(resourceGroup),
					"resource_id":     gomega.Equal(resourceID),
					"resource_name":   gomega.Equal(resourceName),
					"subscription_id": gomega.Equal(subscriptionID),
					"username":        gomega.Equal(username),
				},
			},
		},
		{
			name:     "bad username",
			username: "bad",
			password: password,
			fixtureChecker: func(tt *test, fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
				portalDocument := goodPortalDocument(tt.password)
				fixture.AddPortalDocuments(portalDocument)
				// auth fails, so the record must NOT have been consumed
				checker.AddPortalDocuments(portalDocument)
			},
			wantErrPrefix: "ssh: handshake failed",
			wantLogs: []map[string]types.GomegaMatcher{
				{
					"level":       gomega.Equal(logrus.WarnLevel),
					"msg":         gomega.Equal("authentication failed"),
					"remote_addr": gomega.Not(gomega.BeEmpty()),
					"username":    gomega.Equal("bad"),
				},
			},
		},
		{
			// password not parseable as a UUID: rejected before any db read
			name:          "bad password, not uuid",
			username:      username,
			password:      "bad",
			wantErrPrefix: "ssh: handshake failed",
			wantLogs: []map[string]types.GomegaMatcher{
				{
					"level":       gomega.Equal(logrus.WarnLevel),
					"msg":         gomega.Equal("authentication failed"),
					"remote_addr": gomega.Not(gomega.BeEmpty()),
					"username":    gomega.Equal(username),
				},
			},
		},
		{
			// valid UUID but no matching portal record in the database
			name:          "bad password",
			username:      username,
			password:      password,
			wantErrPrefix: "ssh: handshake failed",
			wantLogs: []map[string]types.GomegaMatcher{
				{
					"level":       gomega.Equal(logrus.WarnLevel),
					"msg":         gomega.Equal("authentication failed"),
					"remote_addr": gomega.Not(gomega.BeEmpty()),
					"username":    gomega.Equal(username),
				},
			},
		},
		{
			// portal record exists but is not an SSH record
			name:     "not ssh record",
			username: username,
			password: password,
			fixtureChecker: func(tt *test, fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
				portalDocument := goodPortalDocument(tt.password)
				portalDocument.Portal.SSH = nil
				fixture.AddPortalDocuments(portalDocument)
				checker.AddPortalDocuments(portalDocument)
			},
			wantErrPrefix: "ssh: handshake failed",
			wantLogs: []map[string]types.GomegaMatcher{
				{
					"level":       gomega.Equal(logrus.WarnLevel),
					"msg":         gomega.Equal("authentication failed"),
					"remote_addr": gomega.Not(gomega.BeEmpty()),
					"username":    gomega.Equal(username),
				},
			},
		},
		{
			// auth succeeds, then the cluster lookup errors out
			name:     "sad openshiftClusters database",
			username: username,
			password: password,
			fixtureChecker: func(tt *test, fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
				portalDocument := goodPortalDocument(tt.password)
				fixture.AddPortalDocuments(portalDocument)

				openShiftClustersClient.SetError(fmt.Errorf("sad"))
			},
			wantErrPrefix: "EOF",
			wantLogs: []map[string]types.GomegaMatcher{
				{
					"level":       gomega.Equal(logrus.InfoLevel),
					"msg":         gomega.Equal("authentication succeeded"),
					"remote_addr": gomega.Not(gomega.BeEmpty()),
					"username":    gomega.Equal(username),
				},
			},
		},
		{
			// portal lookup itself fails: indistinguishable from bad password
			name:     "sad portal database",
			username: username,
			password: password,
			fixtureChecker: func(tt *test, fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
				portalClient.SetError(fmt.Errorf("sad"))
			},
			wantErrPrefix: "ssh: handshake failed",
			wantLogs: []map[string]types.GomegaMatcher{
				{
					"level":       gomega.Equal(logrus.WarnLevel),
					"msg":         gomega.Equal("authentication failed"),
					"remote_addr": gomega.Not(gomega.BeEmpty()),
					"username":    gomega.Equal(username),
				},
			},
		},
		{
			// auth succeeds, then the dial to the cluster fails
			name:     "sad dialer",
			username: username,
			password: password,
			fixtureChecker: func(tt *test, fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
				portalDocument := goodPortalDocument(tt.password)
				fixture.AddPortalDocuments(portalDocument)
				openShiftClusterDocument := goodOpenShiftClusterDocument()
				fixture.AddOpenShiftClusterDocuments(openShiftClusterDocument)
				checker.AddOpenShiftClusterDocuments(openShiftClusterDocument)
			},
			mocks: func(dialer *mock_proxy.MockDialer) {
				dialer.EXPECT().DialContext(gomock.Any(), "tcp", privateEndpointIP+":2201").Return(nil, fmt.Errorf("sad"))
			},
			wantErrPrefix: "EOF",
			wantLogs: []map[string]types.GomegaMatcher{
				{
					"level":       gomega.Equal(logrus.InfoLevel),
					"msg":         gomega.Equal("authentication succeeded"),
					"remote_addr": gomega.Not(gomega.BeEmpty()),
					"username":    gomega.Equal(username),
				},
			},
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			dbPortal, portalClient := testdatabase.NewFakePortal()
			dbOpenShiftClusters, openShiftClustersClient := testdatabase.NewFakeOpenShiftClusters()

			fixture := testdatabase.NewFixture().
				WithOpenShiftClusters(dbOpenShiftClusters).
				WithPortal(dbPortal)

			checker := testdatabase.NewChecker()

			if tt.fixtureChecker != nil {
				tt.fixtureChecker(tt, fixture, checker, openShiftClustersClient, portalClient)
			}

			err := fixture.Create()
			if err != nil {
				t.Fatal(err)
			}

			// in-memory pipe connecting fakeClient to the proxy under test
			client, client1 := bufferedpipe.New()

			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			dialer := mock_proxy.NewMockDialer(ctrl)

			if tt.mocks != nil {
				tt.mocks(dialer)
			}

			hook, log := testlog.New()

			s, err := New(nil, nil, log, nil, hostKey, nil, dbOpenShiftClusters, dbPortal, dialer, &mux.Router{})
			if err != nil {
				t.Fatal(err)
			}

			done := make(chan struct{})

			go func() {
				_ = s.newConn(context.Background(), client1)
				close(done)
			}()

			err = fakeClient(client, &hostKey.PublicKey, tt.username, tt.password)
			// error must carry the expected prefix; nil is only acceptable
			// when no error was expected
			if err != nil && !strings.HasPrefix(err.Error(), tt.wantErrPrefix) ||
				err == nil && tt.wantErrPrefix != "" {
				t.Error(err)
			}

			<-done

			// clear injected errors so the checkers can read the db state
			openShiftClustersClient.SetError(nil)
			portalClient.SetError(nil)

			for _, err = range checker.CheckOpenShiftClusters(openShiftClustersClient) {
				t.Error(err)
			}

			for _, err = range checker.CheckPortals(portalClient) {
				t.Error(err)
			}

			err = testlog.AssertLoggingOutput(hook, tt.wantLogs)
			if err != nil {
				t.Error(err)
			}
		})
	}
}
|
|
@ -0,0 +1,193 @@
|
|||
package ssh
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
"github.com/Azure/ARO-RP/pkg/api/validate"
|
||||
"github.com/Azure/ARO-RP/pkg/database"
|
||||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/middleware"
|
||||
"github.com/Azure/ARO-RP/pkg/proxy"
|
||||
"github.com/Azure/ARO-RP/pkg/util/deployment"
|
||||
)
|
||||
|
||||
const (
	// sshNewTimeout is the TTL of the portal document created by an "ssh/new"
	// call: the issued one-time password is only usable for this long.
	sshNewTimeout = time.Minute
)

// ssh implements the portal's SSH access feature.  It registers an
// AAD-authenticated POST route which mints short-lived SSH passwords for a
// cluster's master nodes (see new() below).
type ssh struct {
	env           env.Core      // provides DeploymentMode() etc.
	log           *logrus.Entry
	baseAccessLog *logrus.Entry
	l             net.Listener

	// elevatedGroupIDs lists the AAD group object IDs whose members are
	// allowed to request SSH credentials.
	elevatedGroupIDs []string

	dbOpenShiftClusters database.OpenShiftClusters
	dbPortal            database.Portal

	// dialer reaches the cluster — presumably via its private endpoint;
	// confirm against the server-side users of this struct.
	dialer proxy.Dialer

	// baseServerConfig carries the SSH host key installed by New().
	baseServerConfig *cryptossh.ServerConfig
	// newPassword returns a fresh one-time password; defaults to random
	// UUIDs and is overridable by tests.
	newPassword func() string
}
|
||||
|
||||
// New constructs the portal SSH component, installs hostKey as the SSH server
// host key, registers the ".../ssh/new" POST route on aadAuthenticatedRouter
// and returns the component.  It returns an error only if hostKey cannot be
// converted into an SSH signer.
func New(env env.Core,
	log *logrus.Entry,
	baseAccessLog *logrus.Entry,
	l net.Listener,
	hostKey *rsa.PrivateKey,
	elevatedGroupIDs []string,
	dbOpenShiftClusters database.OpenShiftClusters,
	dbPortal database.Portal,
	dialer proxy.Dialer,
	aadAuthenticatedRouter *mux.Router) (*ssh, error) {
	s := &ssh{
		env:           env,
		log:           log,
		baseAccessLog: baseAccessLog,
		l:             l,

		elevatedGroupIDs: elevatedGroupIDs,

		dbOpenShiftClusters: dbOpenShiftClusters,
		dbPortal:            dbPortal,

		dialer: dialer,

		baseServerConfig: &cryptossh.ServerConfig{},
		// production password generator; tests replace this with a stub
		newPassword: func() string { return uuid.NewV4().String() },
	}

	// wrap the RSA private key as an SSH signer and install it as host key
	signer, err := cryptossh.NewSignerFromSigner(hostKey)
	if err != nil {
		return nil, err
	}

	s.baseServerConfig.AddHostKey(signer)

	aadAuthenticatedRouter.NewRoute().Methods(http.MethodPost).Path("/subscriptions/{subscriptionId}/resourcegroups/{resourceGroupName}/providers/microsoft.redhatopenshift/openshiftclusters/{resourceName}/ssh/new").HandlerFunc(s.new)

	return s, nil
}
|
||||
|
||||
// request is the JSON body of an "ssh/new" call.  Master selects which master
// node (0-2) the credentials are for.
type request struct {
	Master int `json:"master,omitempty"`
}

// response is the JSON reply to an "ssh/new" call: either a ready-to-paste
// ssh command line plus its one-time password, or an error message.
type response struct {
	Command  string `json:"command,omitempty"`
	Password string `json:"password,omitempty"`
	Error    string `json:"error,omitempty"`
}
|
||||
|
||||
func (s *ssh) new(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
parts := strings.Split(r.URL.Path, "/")
|
||||
if len(parts) < 9 {
|
||||
http.Error(w, "invalid resourceId", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resourceID := strings.Join(parts[:9], "/")
|
||||
if !validate.RxClusterID.MatchString(resourceID) {
|
||||
http.Error(w, fmt.Sprintf("invalid resourceId %q", resourceID), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
mediatype, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||
if mediatype != "application/json" {
|
||||
http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType)
|
||||
return
|
||||
}
|
||||
|
||||
var req *request
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil || req.Master < 0 || req.Master > 2 {
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
elevated := middleware.GroupsIntersect(s.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))
|
||||
if !elevated {
|
||||
s.sendResponse(w, &response{
|
||||
Error: "Elevated access is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
username := r.Context().Value(middleware.ContextKeyUsername).(string)
|
||||
username = strings.SplitN(username, "@", 2)[0]
|
||||
|
||||
password := s.newPassword()
|
||||
|
||||
portalDoc := &api.PortalDocument{
|
||||
ID: password,
|
||||
TTL: int(sshNewTimeout / time.Second),
|
||||
Portal: &api.Portal{
|
||||
Username: ctx.Value(middleware.ContextKeyUsername).(string),
|
||||
ID: resourceID,
|
||||
SSH: &api.SSH{
|
||||
Master: req.Master,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = s.dbPortal.Create(ctx, portalDoc)
|
||||
if err != nil {
|
||||
s.internalServerError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
host := r.Host
|
||||
if strings.ContainsRune(r.Host, ':') {
|
||||
host, _, err = net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
s.internalServerError(w, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
port := ""
|
||||
if s.env.DeploymentMode() == deployment.Development {
|
||||
port = "-p 2222 "
|
||||
}
|
||||
|
||||
s.sendResponse(w, &response{
|
||||
Command: fmt.Sprintf("ssh %s%s@%s", port, username, host),
|
||||
Password: password,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ssh) sendResponse(w http.ResponseWriter, resp *response) {
|
||||
b, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
s.internalServerError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(b)
|
||||
}
|
||||
|
||||
// internalServerError logs err at warning level and replies with a bare HTTP
// 500; the error detail is deliberately not leaked to the client.
func (s *ssh) internalServerError(w http.ResponseWriter, err error) {
	s.log.Warn(err)
	http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
|
|
@ -0,0 +1,190 @@
|
|||
package ssh
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
"github.com/Azure/ARO-RP/pkg/database/cosmosdb"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/middleware"
|
||||
"github.com/Azure/ARO-RP/pkg/portal/util/responsewriter"
|
||||
"github.com/Azure/ARO-RP/pkg/util/deployment"
|
||||
mock_env "github.com/Azure/ARO-RP/pkg/util/mocks/env"
|
||||
"github.com/Azure/ARO-RP/pkg/util/tls"
|
||||
testdatabase "github.com/Azure/ARO-RP/test/database"
|
||||
)
|
||||
|
||||
// TestNew table-tests the "ssh/new" handler end to end through the mux
// router, checking status code, content type, body and the portal documents
// left behind in the fake database.
func TestNew(t *testing.T) {
	resourceID := "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster"
	elevatedGroupIDs := []string{"10000000-0000-0000-0000-000000000000"}
	username := "username"
	password := "password"
	master := 0

	hostKey, _, err := tls.GenerateKeyAndCertificate("proxy", nil, nil, false, false)
	if err != nil {
		t.Fatal(err)
	}

	for _, tt := range []struct {
		name           string
		deploymentMode deployment.Mode
		r              func(*http.Request)  // optional mutation of the default request
		checker        func(*testdatabase.Checker, *cosmosdb.FakePortalDocumentClient)
		wantStatusCode int
		wantBody       string
	}{
		{
			name: "success",
			checker: func(checker *testdatabase.Checker, portalClient *cosmosdb.FakePortalDocumentClient) {
				// expect one portal document keyed by the issued password
				checker.AddPortalDocuments(&api.PortalDocument{
					ID:  password,
					TTL: 60,
					Portal: &api.Portal{
						Username: username,
						ID:       resourceID,
						SSH: &api.SSH{
							Master: master,
						},
					},
				})
			},
			wantStatusCode: http.StatusOK,
			wantBody:       "{\n    \"command\": \"ssh username@localhost\",\n    \"password\": \"password\"\n}",
		},
		{
			name: "bad path",
			r: func(r *http.Request) {
				r.URL.Path = "/subscriptions/BAD/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster/ssh/new"
			},
			wantStatusCode: http.StatusBadRequest,
			wantBody:       "invalid resourceId \"/subscriptions/BAD/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster\"\n",
		},
		{
			name: "bad content type",
			r: func(r *http.Request) {
				r.Header.Set("Content-Type", "bad")
			},
			wantStatusCode: http.StatusUnsupportedMediaType,
			wantBody:       "Unsupported Media Type\n",
		},
		{
			name: "empty request",
			r: func(r *http.Request) {
				r.Body = ioutil.NopCloser(bytes.NewReader(nil))
			},
			wantStatusCode: http.StatusBadRequest,
			wantBody:       "Bad Request\n",
		},
		{
			name: "junk request",
			r: func(r *http.Request) {
				r.Body = ioutil.NopCloser(strings.NewReader("{{"))
			},
			wantStatusCode: http.StatusBadRequest,
			wantBody:       "Bad Request\n",
		},
		{
			name: "not elevated",
			r: func(r *http.Request) {
				// replace the elevated group set with an empty one
				*r = *r.WithContext(context.WithValue(r.Context(), middleware.ContextKeyGroups, []string{}))
			},
			wantStatusCode: http.StatusOK,
			wantBody:       "{\n    \"error\": \"Elevated access is required.\"\n}",
		},
		{
			name: "sad database",
			checker: func(checker *testdatabase.Checker, portalClient *cosmosdb.FakePortalDocumentClient) {
				portalClient.SetError(fmt.Errorf("sad"))
			},
			wantStatusCode: http.StatusInternalServerError,
			wantBody:       "Internal Server Error\n",
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			ctx := context.Background()

			dbPortal, portalClient := testdatabase.NewFakePortal()

			checker := testdatabase.NewChecker()

			if tt.checker != nil {
				tt.checker(checker, portalClient)
			}

			// build a default, valid, authenticated request; tt.r may break it
			ctx = context.WithValue(ctx, middleware.ContextKeyUsername, username)
			ctx = context.WithValue(ctx, middleware.ContextKeyGroups, elevatedGroupIDs)
			r, err := http.NewRequestWithContext(ctx, http.MethodPost,
				"https://localhost:8444"+resourceID+"/ssh/new", strings.NewReader(fmt.Sprintf(`{"master":%d}`, master)))
			if err != nil {
				panic(err)
			}

			r.Header.Set("Content-Type", "application/json")

			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			env := mock_env.NewMockCore(ctrl)
			env.EXPECT().DeploymentMode().AnyTimes().Return(tt.deploymentMode)

			aadAuthenticatedRouter := &mux.Router{}

			s, err := New(env, logrus.NewEntry(logrus.StandardLogger()), nil, nil, hostKey, elevatedGroupIDs, nil, dbPortal, nil, aadAuthenticatedRouter)
			if err != nil {
				t.Fatal(err)
			}

			// make the issued password deterministic
			s.newPassword = func() string { return password }

			if tt.r != nil {
				tt.r(r)
			}

			w := responsewriter.New(r)

			aadAuthenticatedRouter.ServeHTTP(w, r)

			// clear any injected error so the checker can read the fake DB
			portalClient.SetError(nil)

			for _, err = range checker.CheckPortals(portalClient) {
				t.Error(err)
			}

			resp := w.Response()

			if resp.StatusCode != tt.wantStatusCode {
				t.Error(resp.StatusCode)
			}

			wantContentType := "application/json"
			if resp.StatusCode != http.StatusOK {
				wantContentType = "text/plain; charset=utf-8"
			}
			if resp.Header.Get("Content-Type") != wantContentType {
				t.Error(resp.Header.Get("Content-Type"))
			}

			b, err := ioutil.ReadAll(resp.Body)
			if err != nil {
				t.Fatal(err)
			}

			if string(b) != tt.wantBody {
				t.Errorf("%q", string(b))
			}
		})
	}
}
|
|
@ -0,0 +1,76 @@
|
|||
package clientcache
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientCache is a cache for *http.Clients. It allows us to reuse clients and
|
||||
// connections across multiple incoming calls, saving us TCP, TLS and proxy
|
||||
// initialisations.
|
||||
type ClientCache interface {
|
||||
Get(interface{}) *http.Client
|
||||
Put(interface{}, *http.Client)
|
||||
}
|
||||
|
||||
type clientCache struct {
|
||||
mu sync.Mutex
|
||||
now func() time.Time
|
||||
ttl time.Duration
|
||||
m map[interface{}]*v
|
||||
}
|
||||
|
||||
type v struct {
|
||||
expires time.Time
|
||||
cli *http.Client
|
||||
}
|
||||
|
||||
// New returns a new ClientCache
|
||||
func New(ttl time.Duration) ClientCache {
|
||||
return &clientCache{
|
||||
now: time.Now,
|
||||
ttl: ttl,
|
||||
m: map[interface{}]*v{},
|
||||
}
|
||||
}
|
||||
|
||||
// call holding c.mu
|
||||
func (c *clientCache) expire() {
|
||||
now := c.now()
|
||||
for k, v := range c.m {
|
||||
if now.After(v.expires) {
|
||||
v.cli.CloseIdleConnections()
|
||||
delete(c.m, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientCache) Get(k interface{}) (cli *http.Client) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if v := c.m[k]; v != nil {
|
||||
v.expires = c.now().Add(c.ttl)
|
||||
cli = v.cli
|
||||
}
|
||||
|
||||
c.expire()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *clientCache) Put(k interface{}, cli *http.Client) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.m[k] = &v{
|
||||
expires: c.now().Add(c.ttl),
|
||||
cli: cli,
|
||||
}
|
||||
|
||||
c.expire()
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package clientcache
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestClientCache drives the cache with a 1ns TTL against a manually-advanced
// fake clock, covering hit, TTL refresh/revival, expiry-on-Put and explicit
// expire().
func TestClientCache(t *testing.T) {
	var now time.Time

	cli1 := &http.Client{}
	cli2 := &http.Client{}

	// ttl = 1ns; the clock below only moves when we advance `now`
	c := New(1)
	c.(*clientCache).now = func() time.Time { return now }

	// t = 0: put(1), get(1)
	c.Put(1, cli1)
	if c.Get(1) != cli1 {
		t.Error(c.Get(1), cli1)
	}

	now = now.Add(2)

	// t = 2: put(1), get(1) (cli1's ttl should be reset before expiring)
	if c.Get(1) != cli1 {
		t.Error(c.Get(1), cli1)
	}

	now = now.Add(2)

	// t = 4: put(2), get(2) (cli1 should be expired)
	c.Put(2, cli2)
	if c.Get(2) != cli2 {
		t.Error(c.Get(2), cli2)
	}

	if c.Get(1) != nil {
		t.Error(c.Get(1))
	}

	now = now.Add(2)

	// t = 6: cli2 should be expired
	c.(*clientCache).expire()

	if c.Get(2) != nil {
		t.Error(c.Get(2))
	}
}
|
|
@ -0,0 +1,51 @@
|
|||
package responsewriter
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ResponseWriter represents a ResponseWriter
|
||||
type ResponseWriter interface {
|
||||
http.ResponseWriter
|
||||
Response() *http.Response
|
||||
}
|
||||
|
||||
type responseWriter struct {
|
||||
bytes.Buffer
|
||||
r *http.Request
|
||||
h http.Header
|
||||
statusCode int
|
||||
}
|
||||
|
||||
// New returns an http.ResponseWriter on which you can later call Response() to
|
||||
// generate an *http.Response.
|
||||
func New(r *http.Request) ResponseWriter {
|
||||
return &responseWriter{
|
||||
r: r,
|
||||
h: http.Header{},
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Header() http.Header {
|
||||
return w.h
|
||||
}
|
||||
|
||||
func (w *responseWriter) WriteHeader(statusCode int) {
|
||||
w.statusCode = statusCode
|
||||
}
|
||||
|
||||
func (w *responseWriter) Response() *http.Response {
|
||||
return &http.Response{
|
||||
ProtoMajor: w.r.ProtoMajor,
|
||||
ProtoMinor: w.r.ProtoMinor,
|
||||
StatusCode: w.statusCode,
|
||||
Header: w.h,
|
||||
Body: ioutil.NopCloser(&w.Buffer),
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package responsewriter
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestResponseWriter round-trips a canned 404 through the recorder and checks
// the exact serialised HTTP response.
func TestResponseWriter(t *testing.T) {
	w := New(&http.Request{ProtoMajor: 1, ProtoMinor: 1})

	// write headers, status and body through a stdlib handler helper
	http.NotFound(w, nil)

	buf := &bytes.Buffer{}
	_ = w.Response().Write(buf)

	if buf.String() != "HTTP/1.1 404 Not Found\r\nConnection: close\r\nContent-Type: text/plain; charset=utf-8\r\nX-Content-Type-Options: nosniff\r\n\r\n404 page not found\n" {
		t.Error(buf.String())
	}
}
|
|
@ -0,0 +1,8 @@
|
|||
package proxy
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
//go:generate rm -rf ../util/mocks/$GOPACKAGE
|
||||
//go:generate go run ../../vendor/github.com/golang/mock/mockgen -destination=../util/mocks/$GOPACKAGE/$GOPACKAGE.go github.com/Azure/ARO-RP/pkg/$GOPACKAGE Dialer
|
||||
//go:generate go run ../../vendor/golang.org/x/tools/cmd/goimports -local=github.com/Azure/ARO-RP -e -w ../util/mocks/$GOPACKAGE/$GOPACKAGE.go
|
|
@ -1,5 +1,5 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/Azure/ARO-RP/pkg/env (interfaces: Interface)
|
||||
// Source: github.com/Azure/ARO-RP/pkg/env (interfaces: Core,Interface)
|
||||
|
||||
// Package mock_env is a generated GoMock package.
|
||||
package mock_env
|
||||
|
@ -21,6 +21,128 @@ import (
|
|||
refreshable "github.com/Azure/ARO-RP/pkg/util/refreshable"
|
||||
)
|
||||
|
||||
// MockCore is a mock of Core interface
|
||||
type MockCore struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockCoreMockRecorder
|
||||
}
|
||||
|
||||
// MockCoreMockRecorder is the mock recorder for MockCore
|
||||
type MockCoreMockRecorder struct {
|
||||
mock *MockCore
|
||||
}
|
||||
|
||||
// NewMockCore creates a new mock instance
|
||||
func NewMockCore(ctrl *gomock.Controller) *MockCore {
|
||||
mock := &MockCore{ctrl: ctrl}
|
||||
mock.recorder = &MockCoreMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockCore) EXPECT() *MockCoreMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// DeploymentMode mocks base method
|
||||
func (m *MockCore) DeploymentMode() deployment.Mode {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeploymentMode")
|
||||
ret0, _ := ret[0].(deployment.Mode)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeploymentMode indicates an expected call of DeploymentMode
|
||||
func (mr *MockCoreMockRecorder) DeploymentMode() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeploymentMode", reflect.TypeOf((*MockCore)(nil).DeploymentMode))
|
||||
}
|
||||
|
||||
// Environment mocks base method
|
||||
func (m *MockCore) Environment() *azure.Environment {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Environment")
|
||||
ret0, _ := ret[0].(*azure.Environment)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Environment indicates an expected call of Environment
|
||||
func (mr *MockCoreMockRecorder) Environment() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Environment", reflect.TypeOf((*MockCore)(nil).Environment))
|
||||
}
|
||||
|
||||
// Location mocks base method
|
||||
func (m *MockCore) Location() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Location")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Location indicates an expected call of Location
|
||||
func (mr *MockCoreMockRecorder) Location() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Location", reflect.TypeOf((*MockCore)(nil).Location))
|
||||
}
|
||||
|
||||
// NewRPAuthorizer mocks base method
|
||||
func (m *MockCore) NewRPAuthorizer(arg0 string) (autorest.Authorizer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "NewRPAuthorizer", arg0)
|
||||
ret0, _ := ret[0].(autorest.Authorizer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// NewRPAuthorizer indicates an expected call of NewRPAuthorizer
|
||||
func (mr *MockCoreMockRecorder) NewRPAuthorizer(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewRPAuthorizer", reflect.TypeOf((*MockCore)(nil).NewRPAuthorizer), arg0)
|
||||
}
|
||||
|
||||
// ResourceGroup mocks base method
|
||||
func (m *MockCore) ResourceGroup() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ResourceGroup")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ResourceGroup indicates an expected call of ResourceGroup
|
||||
func (mr *MockCoreMockRecorder) ResourceGroup() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResourceGroup", reflect.TypeOf((*MockCore)(nil).ResourceGroup))
|
||||
}
|
||||
|
||||
// SubscriptionID mocks base method
|
||||
func (m *MockCore) SubscriptionID() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SubscriptionID")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SubscriptionID indicates an expected call of SubscriptionID
|
||||
func (mr *MockCoreMockRecorder) SubscriptionID() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscriptionID", reflect.TypeOf((*MockCore)(nil).SubscriptionID))
|
||||
}
|
||||
|
||||
// TenantID mocks base method
|
||||
func (m *MockCore) TenantID() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TenantID")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// TenantID indicates an expected call of TenantID
|
||||
func (mr *MockCoreMockRecorder) TenantID() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TenantID", reflect.TypeOf((*MockCore)(nil).TenantID))
|
||||
}
|
||||
|
||||
// MockInterface is a mock of Interface interface
|
||||
type MockInterface struct {
|
||||
ctrl *gomock.Controller
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/Azure/ARO-RP/pkg/proxy (interfaces: Dialer)
|
||||
|
||||
// Package mock_proxy is a generated GoMock package.
|
||||
package mock_proxy
|
||||
|
||||
import (
|
||||
context "context"
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockDialer is a mock of Dialer interface
|
||||
type MockDialer struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockDialerMockRecorder
|
||||
}
|
||||
|
||||
// MockDialerMockRecorder is the mock recorder for MockDialer
|
||||
type MockDialerMockRecorder struct {
|
||||
mock *MockDialer
|
||||
}
|
||||
|
||||
// NewMockDialer creates a new mock instance
|
||||
func NewMockDialer(ctrl *gomock.Controller) *MockDialer {
|
||||
mock := &MockDialer{ctrl: ctrl}
|
||||
mock.recorder = &MockDialerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockDialer) EXPECT() *MockDialerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// DialContext mocks base method
|
||||
func (m *MockDialer) DialContext(arg0 context.Context, arg1, arg2 string) (net.Conn, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DialContext", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(net.Conn)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DialContext indicates an expected call of DialContext
|
||||
func (mr *MockDialerMockRecorder) DialContext(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialContext", reflect.TypeOf((*MockDialer)(nil).DialContext), arg0, arg1, arg2)
|
||||
}
|
|
@ -40,7 +40,13 @@ func RestConfig(dialer proxy.Dialer, oc *api.OpenShiftCluster) (*rest.Config, er
|
|||
return nil, err
|
||||
}
|
||||
|
||||
restconfig.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
restconfig.Dial = DialContext(dialer, oc)
|
||||
|
||||
return restconfig, nil
|
||||
}
|
||||
|
||||
func DialContext(dialer proxy.Dialer, oc *api.OpenShiftCluster) func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if network != "tcp" {
|
||||
return nil, fmt.Errorf("unimplemented network %q", network)
|
||||
}
|
||||
|
@ -52,6 +58,4 @@ func RestConfig(dialer proxy.Dialer, oc *api.OpenShiftCluster) (*rest.Config, er
|
|||
|
||||
return dialer.DialContext(ctx, network, oc.Properties.NetworkProfile.PrivateEndpointIP+":"+port)
|
||||
}
|
||||
|
||||
return restconfig, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
package roundtripper
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type RoundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
|
@ -21,6 +21,7 @@ type Checker struct {
|
|||
subscriptionDocuments []*api.SubscriptionDocument
|
||||
billingDocuments []*api.BillingDocument
|
||||
asyncOperationDocuments []*api.AsyncOperationDocument
|
||||
portalDocuments []*api.PortalDocument
|
||||
}
|
||||
|
||||
func NewChecker() *Checker {
|
||||
|
@ -43,6 +44,10 @@ func (f *Checker) AddAsyncOperationDocuments(docs ...*api.AsyncOperationDocument
|
|||
f.asyncOperationDocuments = append(f.asyncOperationDocuments, docs...)
|
||||
}
|
||||
|
||||
// AddPortalDocuments registers portal documents that CheckPortals will expect
// to find in the fake database.
func (f *Checker) AddPortalDocuments(docs ...*api.PortalDocument) {
	f.portalDocuments = append(f.portalDocuments, docs...)
}
|
||||
|
||||
func (f *Checker) CheckOpenShiftClusters(openShiftClusters *cosmosdb.FakeOpenShiftClusterDocumentClient) (errs []error) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -134,3 +139,23 @@ func (f *Checker) CheckAsyncOperations(asyncOperations *cosmosdb.FakeAsyncOperat
|
|||
|
||||
return errs
|
||||
}
|
||||
|
||||
// CheckPortals compares the portal documents held by the fake client against
// those registered via AddPortalDocuments, returning one error per
// discrepancy (and nothing when both sides are empty).
func (f *Checker) CheckPortals(portals *cosmosdb.FakePortalDocumentClient) (errs []error) {
	ctx := context.Background()

	all, err := portals.ListAll(ctx, nil)
	if err != nil {
		return []error{err}
	}

	// deep-compare only when the counts match and something was expected;
	// otherwise report a single length-mismatch error
	if len(f.portalDocuments) != 0 && len(all.PortalDocuments) == len(f.portalDocuments) {
		diff := deep.Equal(all.PortalDocuments, f.portalDocuments)
		for _, i := range diff {
			errs = append(errs, errors.New(i))
		}
	} else if len(all.PortalDocuments) != 0 || len(f.portalDocuments) != 0 {
		errs = append(errs, fmt.Errorf("portals length different, %d vs %d", len(all.PortalDocuments), len(f.portalDocuments)))
	}

	return errs
}
|
||||
|
|
|
@ -17,11 +17,13 @@ type Fixture struct {
|
|||
subscriptionDocuments []*api.SubscriptionDocument
|
||||
billingDocuments []*api.BillingDocument
|
||||
asyncOperationDocuments []*api.AsyncOperationDocument
|
||||
portalDocuments []*api.PortalDocument
|
||||
|
||||
openShiftClustersDatabase database.OpenShiftClusters
|
||||
billingDatabase database.Billing
|
||||
subscriptionsDatabase database.Subscriptions
|
||||
asyncOperationsDatabase database.AsyncOperations
|
||||
portalDatabase database.Portal
|
||||
}
|
||||
|
||||
func NewFixture() *Fixture {
|
||||
|
@ -48,6 +50,11 @@ func (f *Fixture) WithAsyncOperations(db database.AsyncOperations) *Fixture {
|
|||
return f
|
||||
}
|
||||
|
||||
// WithPortal wires the portal database into the fixture so queued portal
// documents can be created by Create(); returns f for chaining.
func (f *Fixture) WithPortal(db database.Portal) *Fixture {
	f.portalDatabase = db
	return f
}
|
||||
|
||||
func (f *Fixture) AddOpenShiftClusterDocuments(docs ...*api.OpenShiftClusterDocument) {
|
||||
f.openshiftClusterDocuments = append(f.openshiftClusterDocuments, docs...)
|
||||
}
|
||||
|
@ -64,6 +71,10 @@ func (f *Fixture) AddAsyncOperationDocuments(docs ...*api.AsyncOperationDocument
|
|||
f.asyncOperationDocuments = append(f.asyncOperationDocuments, docs...)
|
||||
}
|
||||
|
||||
// AddPortalDocuments queues portal documents to be written to the portal
// database when Create() runs.
func (f *Fixture) AddPortalDocuments(docs ...*api.PortalDocument) {
	f.portalDocuments = append(f.portalDocuments, docs...)
}
|
||||
|
||||
func (f *Fixture) Create() error {
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -98,5 +109,12 @@ func (f *Fixture) Create() error {
|
|||
}
|
||||
}
|
||||
|
||||
for _, i := range f.portalDocuments {
|
||||
_, err := f.portalDatabase.Create(ctx, i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -47,3 +47,9 @@ func NewFakeAsyncOperations() (db database.AsyncOperations, client *cosmosdb.Fak
|
|||
db = database.NewAsyncOperationsWithProvidedClient(client)
|
||||
return db, client
|
||||
}
|
||||
|
||||
// NewFakePortal returns a portal database backed by an in-memory fake
// document client, along with the fake client itself so tests can inspect it
// or inject errors directly.
func NewFakePortal() (db database.Portal, client *cosmosdb.FakePortalDocumentClient) {
	client = cosmosdb.NewFakePortalDocumentClient(jsonHandle)
	db = database.NewPortalWithProvidedClient(client)
	return db, client
}
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
package bufferedpipe
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// New returns two net.Conns representing either side of a buffered pipe. It's
|
||||
// like net.Pipe() but with buffering. Note that deadlines are currently not
|
||||
// implemented.
|
||||
func New() (net.Conn, net.Conn) {
|
||||
p := &p{
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
|
||||
return &conn{p, 0}, &conn{p, 1}
|
||||
}
|
||||
|
||||
type p struct {
|
||||
cond *sync.Cond
|
||||
buf [2]bytes.Buffer
|
||||
closed [2]bool
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
p *p
|
||||
n int
|
||||
}
|
||||
|
||||
func (c *conn) Read(b []byte) (int, error) {
|
||||
c.p.cond.L.Lock()
|
||||
defer c.p.cond.L.Unlock()
|
||||
|
||||
for {
|
||||
if c.p.closed[c.n] {
|
||||
// Read() concurrently with, or after Close()
|
||||
return 0, errors.New("connection closed")
|
||||
}
|
||||
|
||||
if c.p.buf[c.n^1].Len() > 0 {
|
||||
return c.p.buf[c.n^1].Read(b)
|
||||
}
|
||||
|
||||
if c.p.closed[c.n^1] {
|
||||
// Other side closed and read buffer is drained
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
c.p.cond.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) Write(b []byte) (int, error) {
|
||||
c.p.cond.L.Lock()
|
||||
defer c.p.cond.L.Unlock()
|
||||
|
||||
if c.p.closed[c.n] {
|
||||
return 0, errors.New("connection closed")
|
||||
}
|
||||
|
||||
c.p.cond.Broadcast()
|
||||
return c.p.buf[c.n].Write(b)
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
c.p.cond.L.Lock()
|
||||
defer c.p.cond.L.Unlock()
|
||||
|
||||
if c.p.closed[c.n] {
|
||||
return errors.New("connection closed")
|
||||
}
|
||||
|
||||
c.p.closed[c.n] = true
|
||||
c.p.cond.Broadcast()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) LocalAddr() net.Addr { return &addr{} }
|
||||
func (c *conn) RemoteAddr() net.Addr { return &addr{} }
|
||||
func (c *conn) SetDeadline(time.Time) error { return errors.New("not implemented") }
|
||||
func (c *conn) SetReadDeadline(time.Time) error { return errors.New("not implemented") }
|
||||
func (c *conn) SetWriteDeadline(time.Time) error { return errors.New("not implemented") }
|
||||
|
||||
type addr struct{}
|
||||
|
||||
func (addr) Network() string { return "bufferedpipe" }
|
||||
func (addr) String() string { return "bufferedpipe" }
|
|
@ -7,6 +7,8 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/Azure/ARO-RP/test/util/bufferedpipe"
|
||||
)
|
||||
|
||||
type addr struct{}
|
||||
|
@ -46,7 +48,11 @@ func (*Listener) Addr() net.Addr {
|
|||
}
|
||||
|
||||
func (l *Listener) DialContext(context.Context, string, string) (net.Conn, error) {
|
||||
c1, c2 := net.Pipe()
|
||||
c1, c2 := bufferedpipe.New()
|
||||
l.c <- c1
|
||||
return c2, nil
|
||||
}
|
||||
|
||||
// Enqueue hands an existing connection to the listener; it will be returned
// by a pending or future Accept call.
func (l *Listener) Enqueue(c net.Conn) {
	l.c <- c
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче