This commit is contained in:
Jim Minter 2020-09-25 17:20:03 -05:00
Родитель c38da832db
Коммит 9e5c4f8930
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 0730CBDA10D1A2D3
73 изменённых файлов: 7765 добавлений и 108 удалений

Просмотреть файл

@ -24,6 +24,7 @@ func usage() {
fmt.Fprintf(flag.CommandLine.Output(), " %s deploy config.yaml location\n", os.Args[0])
fmt.Fprintf(flag.CommandLine.Output(), " %s mirror [release_image...]\n", os.Args[0])
fmt.Fprintf(flag.CommandLine.Output(), " %s monitor\n", os.Args[0])
fmt.Fprintf(flag.CommandLine.Output(), " %s portal\n", os.Args[0])
fmt.Fprintf(flag.CommandLine.Output(), " %s rp\n", os.Args[0])
fmt.Fprintf(flag.CommandLine.Output(), " %s operator {master,worker}\n", os.Args[0])
flag.PrintDefaults()
@ -46,6 +47,9 @@ func main() {
var err error
switch strings.ToLower(flag.Arg(0)) {
case "deploy":
checkArgs(3)
err = deploy(ctx, log)
case "mirror":
checkMinArgs(1)
err = mirror(ctx, log)
@ -55,9 +59,9 @@ func main() {
case "rp":
checkArgs(1)
err = rp(ctx, log)
case "deploy":
checkArgs(3)
err = deploy(ctx, log)
case "portal":
checkArgs(1)
err = portal(ctx, log)
case "operator":
checkArgs(2)
err = operator(ctx, log)

181
cmd/aro/portal.go Normal file
Просмотреть файл

@ -0,0 +1,181 @@
package main
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"context"
"crypto/x509"
"fmt"
"net"
"os"
"strings"
uuid "github.com/satori/go.uuid"
"github.com/sirupsen/logrus"
"github.com/Azure/ARO-RP/pkg/database"
"github.com/Azure/ARO-RP/pkg/deploy/generator"
"github.com/Azure/ARO-RP/pkg/env"
"github.com/Azure/ARO-RP/pkg/metrics/noop"
pkgportal "github.com/Azure/ARO-RP/pkg/portal"
"github.com/Azure/ARO-RP/pkg/portal/middleware"
"github.com/Azure/ARO-RP/pkg/proxy"
"github.com/Azure/ARO-RP/pkg/util/deployment"
"github.com/Azure/ARO-RP/pkg/util/encryption"
"github.com/Azure/ARO-RP/pkg/util/keyvault"
)
func portal(ctx context.Context, log *logrus.Entry) error {
_env, err := env.NewCore(ctx, log)
if err != nil {
return err
}
if _env.DeploymentMode() != deployment.Development {
for _, key := range []string{
"PORTAL_HOSTNAME",
} {
if _, found := os.LookupEnv(key); !found {
return fmt.Errorf("environment variable %q unset", key)
}
}
}
for _, key := range []string{
"AZURE_PORTAL_CLIENT_ID",
"AZURE_PORTAL_ACCESS_GROUP_IDS",
"AZURE_PORTAL_ELEVATED_GROUP_IDS",
} {
if _, found := os.LookupEnv(key); !found {
return fmt.Errorf("environment variable %q unset", key)
}
}
groupIDs, err := parseGroupIDs(os.Getenv("AZURE_PORTAL_ACCESS_GROUP_IDS"))
if err != nil {
return err
}
elevatedGroupIDs, err := parseGroupIDs(os.Getenv("AZURE_PORTAL_ELEVATED_GROUP_IDS"))
if err != nil {
return err
}
rpKVAuthorizer, err := _env.NewRPAuthorizer(_env.Environment().ResourceIdentifiers.KeyVault)
if err != nil {
return err
}
// TODO: should not be using the service keyvault here
serviceKeyvaultURI, err := keyvault.URI(_env, generator.ServiceKeyvaultSuffix)
if err != nil {
return err
}
serviceKeyvault := keyvault.NewManager(rpKVAuthorizer, serviceKeyvaultURI)
key, err := serviceKeyvault.GetBase64Secret(ctx, env.EncryptionSecretName)
if err != nil {
return err
}
cipher, err := encryption.NewXChaCha20Poly1305(ctx, key)
if err != nil {
return err
}
dbc, err := database.NewDatabaseClient(ctx, log.WithField("component", "database"), _env, &noop.Noop{}, cipher)
if err != nil {
return err
}
dbOpenShiftClusters, err := database.NewOpenShiftClusters(ctx, _env.DeploymentMode(), dbc)
if err != nil {
return err
}
dbPortal, err := database.NewPortal(ctx, _env.DeploymentMode(), dbc)
if err != nil {
return err
}
portalKeyvaultURI, err := keyvault.URI(_env, generator.PortalKeyvaultSuffix)
if err != nil {
return err
}
portalKeyvault := keyvault.NewManager(rpKVAuthorizer, portalKeyvaultURI)
servingKey, servingCerts, err := portalKeyvault.GetCertificateSecret(ctx, env.PortalServerSecretName)
if err != nil {
return err
}
clientKey, clientCerts, err := portalKeyvault.GetCertificateSecret(ctx, env.PortalServerClientSecretName)
if err != nil {
return err
}
sessionKey, err := portalKeyvault.GetBase64Secret(ctx, env.PortalServerSessionKeySecretName)
if err != nil {
return err
}
b, err := portalKeyvault.GetBase64Secret(ctx, env.PortalServerSSHKeySecretName)
if err != nil {
return err
}
sshKey, err := x509.ParsePKCS1PrivateKey(b)
if err != nil {
return err
}
dialer, err := proxy.NewDialer(_env.DeploymentMode())
if err != nil {
return err
}
clientID := os.Getenv("AZURE_PORTAL_CLIENT_ID")
verifier, err := middleware.NewVerifier(ctx, _env.TenantID(), clientID)
if err != nil {
return err
}
hostname := "localhost:8444"
address := "localhost:8444"
sshAddress := "localhost:2222"
if _env.DeploymentMode() != deployment.Development {
hostname = os.Getenv("PORTAL_HOSTNAME")
address = ":8444"
sshAddress = ":2222"
}
l, err := net.Listen("tcp", address)
if err != nil {
return err
}
sshl, err := net.Listen("tcp", sshAddress)
if err != nil {
return err
}
log.Print("listening")
p := pkgportal.NewPortal(_env, log.WithField("component", "portal"), log.WithField("component", "portal-access"), l, sshl, verifier, hostname, servingKey, servingCerts, clientID, clientKey, clientCerts, sessionKey, sshKey, groupIDs, elevatedGroupIDs, dbOpenShiftClusters, dbPortal, dialer)
return p.Run(ctx)
}
func parseGroupIDs(_groupIDs string) ([]string, error) {
groupIDs := strings.Split(_groupIDs, ",")
for _, groupID := range groupIDs {
_, err := uuid.FromString(groupID)
if err != nil {
return nil, err
}
}
return groupIDs, nil
}

Просмотреть файл

@ -129,6 +129,28 @@
"[resourceId('Microsoft.DocumentDB/databaseAccounts/sqlDatabases', parameters('databaseAccountName'), parameters('databaseName'))]"
]
},
{
"properties": {
"resource": {
"id": "Portal",
"partitionKey": {
"paths": [
"/id"
],
"kind": "Hash"
},
"defaultTtl": -1
},
"options": {}
},
"name": "[concat(parameters('databaseAccountName'), '/', parameters('databaseName'), '/Portal')]",
"type": "Microsoft.DocumentDB/databaseAccounts/sqlDatabases/containers",
"location": "[resourceGroup().location]",
"apiVersion": "2019-08-01",
"dependsOn": [
"[resourceId('Microsoft.DocumentDB/databaseAccounts/sqlDatabases', parameters('databaseAccountName'), parameters('databaseName'))]"
]
},
{
"properties": {
"resource": {

Просмотреть файл

@ -116,6 +116,47 @@
"location": "[resourceGroup().location]",
"apiVersion": "2016-10-01"
},
{
"properties": {
"tenantId": "[subscription().tenantId]",
"sku": {
"family": "A",
"name": "standard"
},
"accessPolicies": [
{
"tenantId": "[subscription().tenantId]",
"objectId": "[parameters('rpServicePrincipalId')]",
"permissions": {
"secrets": [
"get"
]
}
},
{
"tenantId": "[subscription().tenantId]",
"objectId": "[parameters('adminObjectId')]",
"permissions": {
"secrets": [
"set",
"list"
],
"certificates": [
"delete",
"get",
"import",
"list"
]
}
}
],
"enableSoftDelete": true
},
"name": "[concat(parameters('keyvaultPrefix'), '-por')]",
"type": "Microsoft.KeyVault/vaults",
"location": "[resourceGroup().location]",
"apiVersion": "2016-10-01"
},
{
"properties": {
"tenantId": "[subscription().tenantId]",

Просмотреть файл

@ -38,6 +38,15 @@
"mdsdEnvironment": {
"value": ""
},
"portalAccessGroupIds": {
"value": ""
},
"portalClientId": {
"value": ""
},
"portalElevatedGroupIds": {
"value": ""
},
"rpImage": {
"value": ""
},

Просмотреть файл

@ -8,6 +8,9 @@
"extraClusterKeyvaultAccessPolicies": {
"value": []
},
"extraPortalKeyvaultAccessPolicies": {
"value": []
},
"extraServiceKeyvaultAccessPolicies": {
"value": []
},

Просмотреть файл

@ -34,6 +34,17 @@
}
}
],
"portalKeyvaultAccessPolicies": [
{
"tenantId": "[subscription().tenantId]",
"objectId": "[parameters('rpServicePrincipalId')]",
"permissions": {
"secrets": [
"get"
]
}
}
],
"serviceKeyvaultAccessPolicies": [
{
"tenantId": "[subscription().tenantId]",
@ -55,6 +66,10 @@
"type": "array",
"defaultValue": []
},
"extraPortalKeyvaultAccessPolicies": {
"type": "array",
"defaultValue": []
},
"extraServiceKeyvaultAccessPolicies": {
"type": "array",
"defaultValue": []
@ -140,6 +155,22 @@
"condition": "[parameters('fullDeploy')]",
"apiVersion": "2016-10-01"
},
{
"properties": {
"tenantId": "[subscription().tenantId]",
"sku": {
"family": "A",
"name": "standard"
},
"accessPolicies": "[concat(variables('portalKeyvaultAccessPolicies'), parameters('extraPortalKeyvaultAccessPolicies'))]",
"enableSoftDelete": true
},
"name": "[concat(parameters('keyvaultPrefix'), '-por')]",
"type": "Microsoft.KeyVault/vaults",
"location": "[resourceGroup().location]",
"condition": "[parameters('fullDeploy')]",
"apiVersion": "2016-10-01"
},
{
"properties": {
"tenantId": "[subscription().tenantId]",

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -182,6 +182,31 @@ locations.
>/dev/null
```
1. Create an AAD application which will fake up the portal client.
This application requires client certificate authentication to be enabled. A
suitable key/certificate file can be generated using the following helper
utility:
```bash
go run ./hack/genkey -client portal-client
mv portal-client.* secrets
```
```bash
AZURE_PORTAL_CLIENT_ID="$(az ad app create \
--display-name aro-v4-portal-shared \
--identifier-uris "https://$(uuidgen)/" \
--reply-urls "https://localhost:8444/callback" \
--query appId \
-o tsv)"
az ad app credential reset \
--id "$AZURE_PORTAL_CLIENT_ID" \
--cert "$(base64 -w0 <secrets/portal-client.crt)" >/dev/null
```
TODO: more steps are needed to configure aro-v4-portal-shared.
## Certificates
@ -258,6 +283,9 @@ locations.
export AZURE_ARM_CLIENT_ID='$AZURE_ARM_CLIENT_ID'
export AZURE_ARM_CLIENT_SECRET='$AZURE_ARM_CLIENT_SECRET'
export AZURE_FP_CLIENT_ID='$AZURE_FP_CLIENT_ID'
export AZURE_PORTAL_CLIENT_ID='$AZURE_PORTAL_CLIENT_ID'
export AZURE_PORTAL_ACCESS_GROUP_IDS='$ADMIN_OBJECT_ID'
export AZURE_PORTAL_ELEVATED_GROUP_IDS='$ADMIN_OBJECT_ID'
export AZURE_CLIENT_ID='$AZURE_CLIENT_ID'
export AZURE_CLIENT_SECRET='$AZURE_CLIENT_SECRET'
export AZURE_RP_CLIENT_ID='$AZURE_RP_CLIENT_ID'

Просмотреть файл

@ -84,6 +84,14 @@ import_certs_secrets() {
--vault-name "$KEYVAULT_PREFIX-svc" \
--name rp-server \
--file secrets/localhost.pem
az keyvault certificate import \
--vault-name "$KEYVAULT_PREFIX-por" \
--name portal-server \
--file secrets/localhost.pem
az keyvault certificate import \
--vault-name "$KEYVAULT_PREFIX-por" \
--name portal-client \
--file secrets/portal-client.pem
az keyvault certificate import \
--vault-name "$KEYVAULT_PREFIX-svc" \
--name cluster-mdsd \
@ -104,6 +112,22 @@ import_certs_secrets() {
--vault-name "$KEYVAULT_PREFIX-svc" \
--name fe-encryption-key \
--value "$(openssl rand -base64 32)"
az keyvault secret list \
--vault-name "$KEYVAULT_PREFIX-por" \
--query '[].name' \
-o tsv | grep -q ^portal-session-key$ || \
az keyvault secret set \
--vault-name "$KEYVAULT_PREFIX-por" \
--name portal-session-key \
--value "$(openssl rand -base64 32)"
az keyvault secret list \
--vault-name "$KEYVAULT_PREFIX-por" \
--query '[].name' \
-o tsv | grep -q ^portal-sshkey$ || \
az keyvault secret set \
--vault-name "$KEYVAULT_PREFIX-por" \
--name portal-sshkey \
--value "$(openssl genpkey -algorithm rsa -pkeyopt rsa_keygen_bits:2048 -outform der | base64 -w0)"
}
update_parent_domain_dns_zone() {

27
pkg/api/portal.go Normal file
Просмотреть файл

@ -0,0 +1,27 @@
package api
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
// Portal represents a portal
type Portal struct {
MissingFields
Username string `json:"username"`
ID string `json:"id,omitempty"`
SSH *SSH `json:"ssh,omitempty"`
Kubeconfig *Kubeconfig `json:"kubeconfig,omitempty"`
}
type SSH struct {
MissingFields
Master int `json:"master,omitempty"`
}
type Kubeconfig struct {
MissingFields
Elevated bool `json:"elevated,omitempty"`
}

38
pkg/api/portaldocument.go Normal file
Просмотреть файл

@ -0,0 +1,38 @@
package api
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
// PortalDocuments represents portal documents.
// pkg/database/cosmosdb requires its definition.
type PortalDocuments struct {
Count int `json:"_count,omitempty"`
ResourceID string `json:"_rid,omitempty"`
PortalDocuments []*PortalDocument `json:"Documents,omitempty"`
}
func (c *PortalDocuments) String() string {
return encodeJSON(c)
}
// PortalDocument represents a portal document.
// pkg/database/cosmosdb requires its definition.
type PortalDocument struct {
MissingFields
ID string `json:"id,omitempty"`
ResourceID string `json:"_rid,omitempty"`
Timestamp int `json:"_ts,omitempty"`
Self string `json:"_self,omitempty"`
ETag string `json:"_etag,omitempty"`
Attachments string `json:"_attachments,omitempty"`
TTL int `json:"ttl,omitempty"`
LSN int `json:"_lsn,omitempty"`
Metadata map[string]interface{} `json:"_metadata,omitempty"`
Portal *Portal `json:"portal,omitempty"`
}
func (c *PortalDocument) String() string {
return encodeJSON(c)
}

Просмотреть файл

@ -9,6 +9,7 @@ import (
// Regular expressions used to validate the format of resource names and IDs acceptable by API.
var (
RxClusterID = regexp.MustCompile(`(?i)^/subscriptions/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/resourceGroups/[-a-z0-9_().]{0,89}[-a-z0-9_()]/providers/Microsoft\.RedHatOpenShift/openShiftClusters/[-a-z0-9_().]{0,89}[-a-z0-9_()]$`)
RxResourceGroupID = regexp.MustCompile(`(?i)^/subscriptions/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/resourceGroups/[-a-z0-9_().]{0,89}[-a-z0-9_()]$`)
RxSubnetID = regexp.MustCompile(`(?i)^/subscriptions/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/resourceGroups/[-a-z0-9_().]{0,89}[-a-z0-9_()]/providers/Microsoft\.Network/virtualNetworks/[-a-z0-9_.]{2,64}/subnets/[-a-z0-9_.]{2,80}$`)
RxDomainName = regexp.MustCompile(`^` +

Просмотреть файл

@ -257,5 +257,5 @@ func (m *manager) saveGraph(ctx context.Context, g graph) error {
return err
}
return graph.CreateBlockBlobFromReader(bytes.NewReader([]byte(output)), nil)
return graph.CreateBlockBlobFromReader(bytes.NewReader(output), nil)
}

Просмотреть файл

@ -0,0 +1,313 @@
// Code generated by github.com/jim-minter/go-cosmosdb, DO NOT EDIT.
package cosmosdb
import (
"context"
"net/http"
"strconv"
"strings"
pkg "github.com/Azure/ARO-RP/pkg/api"
)
type portalDocumentClient struct {
*databaseClient
path string
}
// PortalDocumentClient is a portalDocument client
type PortalDocumentClient interface {
Create(context.Context, string, *pkg.PortalDocument, *Options) (*pkg.PortalDocument, error)
List(*Options) PortalDocumentIterator
ListAll(context.Context, *Options) (*pkg.PortalDocuments, error)
Get(context.Context, string, string, *Options) (*pkg.PortalDocument, error)
Replace(context.Context, string, *pkg.PortalDocument, *Options) (*pkg.PortalDocument, error)
Delete(context.Context, string, *pkg.PortalDocument, *Options) error
Query(string, *Query, *Options) PortalDocumentRawIterator
QueryAll(context.Context, string, *Query, *Options) (*pkg.PortalDocuments, error)
ChangeFeed(*Options) PortalDocumentIterator
}
type portalDocumentChangeFeedIterator struct {
*portalDocumentClient
continuation string
options *Options
}
type portalDocumentListIterator struct {
*portalDocumentClient
continuation string
done bool
options *Options
}
type portalDocumentQueryIterator struct {
*portalDocumentClient
partitionkey string
query *Query
continuation string
done bool
options *Options
}
// PortalDocumentIterator is a portalDocument iterator
type PortalDocumentIterator interface {
Next(context.Context, int) (*pkg.PortalDocuments, error)
Continuation() string
}
// PortalDocumentRawIterator is a portalDocument raw iterator
type PortalDocumentRawIterator interface {
PortalDocumentIterator
NextRaw(context.Context, int, interface{}) error
}
// NewPortalDocumentClient returns a new portalDocument client
func NewPortalDocumentClient(collc CollectionClient, collid string) PortalDocumentClient {
return &portalDocumentClient{
databaseClient: collc.(*collectionClient).databaseClient,
path: collc.(*collectionClient).path + "/colls/" + collid,
}
}
func (c *portalDocumentClient) all(ctx context.Context, i PortalDocumentIterator) (*pkg.PortalDocuments, error) {
allportalDocuments := &pkg.PortalDocuments{}
for {
portalDocuments, err := i.Next(ctx, -1)
if err != nil {
return nil, err
}
if portalDocuments == nil {
break
}
allportalDocuments.Count += portalDocuments.Count
allportalDocuments.ResourceID = portalDocuments.ResourceID
allportalDocuments.PortalDocuments = append(allportalDocuments.PortalDocuments, portalDocuments.PortalDocuments...)
}
return allportalDocuments, nil
}
func (c *portalDocumentClient) Create(ctx context.Context, partitionkey string, newportalDocument *pkg.PortalDocument, options *Options) (portalDocument *pkg.PortalDocument, err error) {
headers := http.Header{}
headers.Set("X-Ms-Documentdb-Partitionkey", `["`+partitionkey+`"]`)
if options == nil {
options = &Options{}
}
options.NoETag = true
err = c.setOptions(options, newportalDocument, headers)
if err != nil {
return
}
err = c.do(ctx, http.MethodPost, c.path+"/docs", "docs", c.path, http.StatusCreated, &newportalDocument, &portalDocument, headers)
return
}
func (c *portalDocumentClient) List(options *Options) PortalDocumentIterator {
continuation := ""
if options != nil {
continuation = options.Continuation
}
return &portalDocumentListIterator{portalDocumentClient: c, options: options, continuation: continuation}
}
func (c *portalDocumentClient) ListAll(ctx context.Context, options *Options) (*pkg.PortalDocuments, error) {
return c.all(ctx, c.List(options))
}
func (c *portalDocumentClient) Get(ctx context.Context, partitionkey, portalDocumentid string, options *Options) (portalDocument *pkg.PortalDocument, err error) {
headers := http.Header{}
headers.Set("X-Ms-Documentdb-Partitionkey", `["`+partitionkey+`"]`)
err = c.setOptions(options, nil, headers)
if err != nil {
return
}
err = c.do(ctx, http.MethodGet, c.path+"/docs/"+portalDocumentid, "docs", c.path+"/docs/"+portalDocumentid, http.StatusOK, nil, &portalDocument, headers)
return
}
func (c *portalDocumentClient) Replace(ctx context.Context, partitionkey string, newportalDocument *pkg.PortalDocument, options *Options) (portalDocument *pkg.PortalDocument, err error) {
headers := http.Header{}
headers.Set("X-Ms-Documentdb-Partitionkey", `["`+partitionkey+`"]`)
err = c.setOptions(options, newportalDocument, headers)
if err != nil {
return
}
err = c.do(ctx, http.MethodPut, c.path+"/docs/"+newportalDocument.ID, "docs", c.path+"/docs/"+newportalDocument.ID, http.StatusOK, &newportalDocument, &portalDocument, headers)
return
}
func (c *portalDocumentClient) Delete(ctx context.Context, partitionkey string, portalDocument *pkg.PortalDocument, options *Options) (err error) {
headers := http.Header{}
headers.Set("X-Ms-Documentdb-Partitionkey", `["`+partitionkey+`"]`)
err = c.setOptions(options, portalDocument, headers)
if err != nil {
return
}
err = c.do(ctx, http.MethodDelete, c.path+"/docs/"+portalDocument.ID, "docs", c.path+"/docs/"+portalDocument.ID, http.StatusNoContent, nil, nil, headers)
return
}
func (c *portalDocumentClient) Query(partitionkey string, query *Query, options *Options) PortalDocumentRawIterator {
continuation := ""
if options != nil {
continuation = options.Continuation
}
return &portalDocumentQueryIterator{portalDocumentClient: c, partitionkey: partitionkey, query: query, options: options, continuation: continuation}
}
func (c *portalDocumentClient) QueryAll(ctx context.Context, partitionkey string, query *Query, options *Options) (*pkg.PortalDocuments, error) {
return c.all(ctx, c.Query(partitionkey, query, options))
}
func (c *portalDocumentClient) ChangeFeed(options *Options) PortalDocumentIterator {
continuation := ""
if options != nil {
continuation = options.Continuation
}
return &portalDocumentChangeFeedIterator{portalDocumentClient: c, options: options, continuation: continuation}
}
func (c *portalDocumentClient) setOptions(options *Options, portalDocument *pkg.PortalDocument, headers http.Header) error {
if options == nil {
return nil
}
if portalDocument != nil && !options.NoETag {
if portalDocument.ETag == "" {
return ErrETagRequired
}
headers.Set("If-Match", portalDocument.ETag)
}
if len(options.PreTriggers) > 0 {
headers.Set("X-Ms-Documentdb-Pre-Trigger-Include", strings.Join(options.PreTriggers, ","))
}
if len(options.PostTriggers) > 0 {
headers.Set("X-Ms-Documentdb-Post-Trigger-Include", strings.Join(options.PostTriggers, ","))
}
if len(options.PartitionKeyRangeID) > 0 {
headers.Set("X-Ms-Documentdb-PartitionKeyRangeID", options.PartitionKeyRangeID)
}
return nil
}
func (i *portalDocumentChangeFeedIterator) Next(ctx context.Context, maxItemCount int) (portalDocuments *pkg.PortalDocuments, err error) {
headers := http.Header{}
headers.Set("A-IM", "Incremental feed")
headers.Set("X-Ms-Max-Item-Count", strconv.Itoa(maxItemCount))
if i.continuation != "" {
headers.Set("If-None-Match", i.continuation)
}
err = i.setOptions(i.options, nil, headers)
if err != nil {
return
}
err = i.do(ctx, http.MethodGet, i.path+"/docs", "docs", i.path, http.StatusOK, nil, &portalDocuments, headers)
if IsErrorStatusCode(err, http.StatusNotModified) {
err = nil
}
if err != nil {
return
}
i.continuation = headers.Get("Etag")
return
}
func (i *portalDocumentChangeFeedIterator) Continuation() string {
return i.continuation
}
func (i *portalDocumentListIterator) Next(ctx context.Context, maxItemCount int) (portalDocuments *pkg.PortalDocuments, err error) {
if i.done {
return
}
headers := http.Header{}
headers.Set("X-Ms-Max-Item-Count", strconv.Itoa(maxItemCount))
if i.continuation != "" {
headers.Set("X-Ms-Continuation", i.continuation)
}
err = i.setOptions(i.options, nil, headers)
if err != nil {
return
}
err = i.do(ctx, http.MethodGet, i.path+"/docs", "docs", i.path, http.StatusOK, nil, &portalDocuments, headers)
if err != nil {
return
}
i.continuation = headers.Get("X-Ms-Continuation")
i.done = i.continuation == ""
return
}
func (i *portalDocumentListIterator) Continuation() string {
return i.continuation
}
func (i *portalDocumentQueryIterator) Next(ctx context.Context, maxItemCount int) (portalDocuments *pkg.PortalDocuments, err error) {
err = i.NextRaw(ctx, maxItemCount, &portalDocuments)
return
}
func (i *portalDocumentQueryIterator) NextRaw(ctx context.Context, maxItemCount int, raw interface{}) (err error) {
if i.done {
return
}
headers := http.Header{}
headers.Set("X-Ms-Max-Item-Count", strconv.Itoa(maxItemCount))
headers.Set("X-Ms-Documentdb-Isquery", "True")
headers.Set("Content-Type", "application/query+json")
if i.partitionkey != "" {
headers.Set("X-Ms-Documentdb-Partitionkey", `["`+i.partitionkey+`"]`)
} else {
headers.Set("X-Ms-Documentdb-Query-Enablecrosspartition", "True")
}
if i.continuation != "" {
headers.Set("X-Ms-Continuation", i.continuation)
}
err = i.setOptions(i.options, nil, headers)
if err != nil {
return
}
err = i.do(ctx, http.MethodPost, i.path+"/docs", "docs", i.path, http.StatusOK, &i.query, &raw, headers)
if err != nil {
return
}
i.continuation = headers.Get("X-Ms-Continuation")
i.done = i.continuation == ""
return
}
func (i *portalDocumentQueryIterator) Continuation() string {
return i.continuation
}

Просмотреть файл

@ -0,0 +1,360 @@
// Code generated by github.com/jim-minter/go-cosmosdb, DO NOT EDIT.
package cosmosdb
import (
"context"
"fmt"
"net/http"
"sync"
"github.com/ugorji/go/codec"
pkg "github.com/Azure/ARO-RP/pkg/api"
)
type fakePortalDocumentTriggerHandler func(context.Context, *pkg.PortalDocument) error
type fakePortalDocumentQueryHandler func(PortalDocumentClient, *Query, *Options) PortalDocumentRawIterator
var _ PortalDocumentClient = &FakePortalDocumentClient{}
// NewFakePortalDocumentClient returns a FakePortalDocumentClient
func NewFakePortalDocumentClient(h *codec.JsonHandle) *FakePortalDocumentClient {
return &FakePortalDocumentClient{
portalDocuments: make(map[string][]byte),
triggerHandlers: make(map[string]fakePortalDocumentTriggerHandler),
queryHandlers: make(map[string]fakePortalDocumentQueryHandler),
jsonHandle: h,
lock: &sync.RWMutex{},
}
}
// FakePortalDocumentClient is a FakePortalDocumentClient
type FakePortalDocumentClient struct {
portalDocuments map[string][]byte
jsonHandle *codec.JsonHandle
lock *sync.RWMutex
triggerHandlers map[string]fakePortalDocumentTriggerHandler
queryHandlers map[string]fakePortalDocumentQueryHandler
sorter func([]*pkg.PortalDocument)
// returns true if documents conflict
conflictChecker func(*pkg.PortalDocument, *pkg.PortalDocument) bool
// err, if not nil, is an error to return when attempting to communicate
// with this Client
err error
}
func (c *FakePortalDocumentClient) decodePortalDocument(s []byte) (portalDocument *pkg.PortalDocument, err error) {
err = codec.NewDecoderBytes(s, c.jsonHandle).Decode(&portalDocument)
return
}
func (c *FakePortalDocumentClient) encodePortalDocument(portalDocument *pkg.PortalDocument) (b []byte, err error) {
err = codec.NewEncoderBytes(&b, c.jsonHandle).Encode(portalDocument)
return
}
// SetError sets or unsets an error that will be returned on any
// FakePortalDocumentClient method invocation
func (c *FakePortalDocumentClient) SetError(err error) {
c.lock.Lock()
defer c.lock.Unlock()
c.err = err
}
// SetSorter sets or unsets a sorter function which will be used to sort values
// returned by List() for test stability
func (c *FakePortalDocumentClient) SetSorter(sorter func([]*pkg.PortalDocument)) {
c.lock.Lock()
defer c.lock.Unlock()
c.sorter = sorter
}
// SetConflictChecker sets or unsets a function which can be used to validate
// additional unique keys in a PortalDocument
func (c *FakePortalDocumentClient) SetConflictChecker(conflictChecker func(*pkg.PortalDocument, *pkg.PortalDocument) bool) {
c.lock.Lock()
defer c.lock.Unlock()
c.conflictChecker = conflictChecker
}
// SetTriggerHandler sets or unsets a trigger handler
func (c *FakePortalDocumentClient) SetTriggerHandler(triggerName string, trigger fakePortalDocumentTriggerHandler) {
c.lock.Lock()
defer c.lock.Unlock()
c.triggerHandlers[triggerName] = trigger
}
// SetQueryHandler sets or unsets a query handler
func (c *FakePortalDocumentClient) SetQueryHandler(queryName string, query fakePortalDocumentQueryHandler) {
c.lock.Lock()
defer c.lock.Unlock()
c.queryHandlers[queryName] = query
}
func (c *FakePortalDocumentClient) deepCopy(portalDocument *pkg.PortalDocument) (*pkg.PortalDocument, error) {
b, err := c.encodePortalDocument(portalDocument)
if err != nil {
return nil, err
}
return c.decodePortalDocument(b)
}
func (c *FakePortalDocumentClient) apply(ctx context.Context, partitionkey string, portalDocument *pkg.PortalDocument, options *Options, isCreate bool) (*pkg.PortalDocument, error) {
c.lock.Lock()
defer c.lock.Unlock()
if c.err != nil {
return nil, c.err
}
portalDocument, err := c.deepCopy(portalDocument) // copy now because pretriggers can mutate portalDocument
if err != nil {
return nil, err
}
if options != nil {
err := c.processPreTriggers(ctx, portalDocument, options)
if err != nil {
return nil, err
}
}
_, exists := c.portalDocuments[portalDocument.ID]
if isCreate && exists {
return nil, &Error{
StatusCode: http.StatusConflict,
Message: "Entity with the specified id already exists in the system",
}
}
if !isCreate && !exists {
return nil, &Error{StatusCode: http.StatusNotFound}
}
if c.conflictChecker != nil {
for id := range c.portalDocuments {
portalDocumentToCheck, err := c.decodePortalDocument(c.portalDocuments[id])
if err != nil {
return nil, err
}
if c.conflictChecker(portalDocumentToCheck, portalDocument) {
return nil, &Error{
StatusCode: http.StatusConflict,
Message: "Entity with the specified id already exists in the system",
}
}
}
}
b, err := c.encodePortalDocument(portalDocument)
if err != nil {
return nil, err
}
c.portalDocuments[portalDocument.ID] = b
return portalDocument, nil
}
// Create creates a PortalDocument in the database
func (c *FakePortalDocumentClient) Create(ctx context.Context, partitionkey string, portalDocument *pkg.PortalDocument, options *Options) (*pkg.PortalDocument, error) {
return c.apply(ctx, partitionkey, portalDocument, options, true)
}
// Replace replaces a PortalDocument in the database
func (c *FakePortalDocumentClient) Replace(ctx context.Context, partitionkey string, portalDocument *pkg.PortalDocument, options *Options) (*pkg.PortalDocument, error) {
return c.apply(ctx, partitionkey, portalDocument, options, false)
}
// List returns a PortalDocumentIterator to list all PortalDocuments in the database
func (c *FakePortalDocumentClient) List(*Options) PortalDocumentIterator {
c.lock.RLock()
defer c.lock.RUnlock()
if c.err != nil {
return NewFakePortalDocumentErroringRawIterator(c.err)
}
portalDocuments := make([]*pkg.PortalDocument, 0, len(c.portalDocuments))
for _, d := range c.portalDocuments {
r, err := c.decodePortalDocument(d)
if err != nil {
return NewFakePortalDocumentErroringRawIterator(err)
}
portalDocuments = append(portalDocuments, r)
}
if c.sorter != nil {
c.sorter(portalDocuments)
}
return NewFakePortalDocumentIterator(portalDocuments, 0)
}
// ListAll lists all PortalDocuments in the database
func (c *FakePortalDocumentClient) ListAll(ctx context.Context, options *Options) (*pkg.PortalDocuments, error) {
iter := c.List(options)
return iter.Next(ctx, -1)
}
// Get gets a PortalDocument from the database
func (c *FakePortalDocumentClient) Get(ctx context.Context, partitionkey string, id string, options *Options) (*pkg.PortalDocument, error) {
c.lock.RLock()
defer c.lock.RUnlock()
if c.err != nil {
return nil, c.err
}
portalDocument, exists := c.portalDocuments[id]
if !exists {
return nil, &Error{StatusCode: http.StatusNotFound}
}
return c.decodePortalDocument(portalDocument)
}
// Delete deletes a PortalDocument from the database
func (c *FakePortalDocumentClient) Delete(ctx context.Context, partitionKey string, portalDocument *pkg.PortalDocument, options *Options) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.err != nil {
return c.err
}
_, exists := c.portalDocuments[portalDocument.ID]
if !exists {
return &Error{StatusCode: http.StatusNotFound}
}
delete(c.portalDocuments, portalDocument.ID)
return nil
}
// ChangeFeed is unimplemented
func (c *FakePortalDocumentClient) ChangeFeed(*Options) PortalDocumentIterator {
c.lock.RLock()
defer c.lock.RUnlock()
if c.err != nil {
return NewFakePortalDocumentErroringRawIterator(c.err)
}
return NewFakePortalDocumentErroringRawIterator(ErrNotImplemented)
}
func (c *FakePortalDocumentClient) processPreTriggers(ctx context.Context, portalDocument *pkg.PortalDocument, options *Options) error {
for _, triggerName := range options.PreTriggers {
if triggerHandler := c.triggerHandlers[triggerName]; triggerHandler != nil {
err := triggerHandler(ctx, portalDocument)
if err != nil {
return err
}
} else {
return ErrNotImplemented
}
}
return nil
}
// Query calls a query handler to implement database querying
func (c *FakePortalDocumentClient) Query(name string, query *Query, options *Options) PortalDocumentRawIterator {
c.lock.RLock()
defer c.lock.RUnlock()
if c.err != nil {
return NewFakePortalDocumentErroringRawIterator(c.err)
}
if queryHandler := c.queryHandlers[query.Query]; queryHandler != nil {
return queryHandler(c, query, options)
}
return NewFakePortalDocumentErroringRawIterator(ErrNotImplemented)
}
// QueryAll calls a query handler to implement database querying
func (c *FakePortalDocumentClient) QueryAll(ctx context.Context, partitionkey string, query *Query, options *Options) (*pkg.PortalDocuments, error) {
iter := c.Query("", query, options)
return iter.Next(ctx, -1)
}
func NewFakePortalDocumentIterator(portalDocuments []*pkg.PortalDocument, continuation int) PortalDocumentRawIterator {
return &fakePortalDocumentIterator{portalDocuments: portalDocuments, continuation: continuation}
}
type fakePortalDocumentIterator struct {
portalDocuments []*pkg.PortalDocument
continuation int
done bool
}
func (i *fakePortalDocumentIterator) NextRaw(ctx context.Context, maxItemCount int, out interface{}) error {
return ErrNotImplemented
}
func (i *fakePortalDocumentIterator) Next(ctx context.Context, maxItemCount int) (*pkg.PortalDocuments, error) {
if i.done {
return nil, nil
}
var portalDocuments []*pkg.PortalDocument
if maxItemCount == -1 {
portalDocuments = i.portalDocuments[i.continuation:]
i.continuation = len(i.portalDocuments)
i.done = true
} else {
max := i.continuation + maxItemCount
if max > len(i.portalDocuments) {
max = len(i.portalDocuments)
}
portalDocuments = i.portalDocuments[i.continuation:max]
i.continuation += max
i.done = i.Continuation() == ""
}
return &pkg.PortalDocuments{
PortalDocuments: portalDocuments,
Count: len(portalDocuments),
}, nil
}
func (i *fakePortalDocumentIterator) Continuation() string {
if i.continuation >= len(i.portalDocuments) {
return ""
}
return fmt.Sprintf("%d", i.continuation)
}
// NewFakePortalDocumentErroringRawIterator returns a PortalDocumentRawIterator which
// whose methods return the given error
func NewFakePortalDocumentErroringRawIterator(err error) PortalDocumentRawIterator {
return &fakePortalDocumentErroringRawIterator{err: err}
}
type fakePortalDocumentErroringRawIterator struct {
err error
}
func (i *fakePortalDocumentErroringRawIterator) Next(ctx context.Context, maxItemCount int) (*pkg.PortalDocuments, error) {
return nil, i.err
}
func (i *fakePortalDocumentErroringRawIterator) NextRaw(context.Context, int, interface{}) error {
return i.err
}
func (i *fakePortalDocumentErroringRawIterator) Continuation() string {
return ""
}

Просмотреть файл

@ -30,6 +30,7 @@ const (
collBilling = "Billing"
collMonitors = "Monitors"
collOpenShiftClusters = "OpenShiftClusters"
collPortal = "Portal"
collSubscriptions = "Subscriptions"
)

Просмотреть файл

@ -43,6 +43,7 @@ type OpenShiftClusters interface {
Delete(context.Context, *api.OpenShiftClusterDocument) error
ChangeFeed() cosmosdb.OpenShiftClusterDocumentIterator
List(string) cosmosdb.OpenShiftClusterDocumentIterator
ListAll(context.Context) (*api.OpenShiftClusterDocuments, error)
ListByPrefix(string, string, string) (cosmosdb.OpenShiftClusterDocumentIterator, error)
Dequeue(context.Context) (*api.OpenShiftClusterDocument, error)
Lease(context.Context, string) (*api.OpenShiftClusterDocument, error)
@ -246,6 +247,10 @@ func (c *openShiftClusters) List(continuation string) cosmosdb.OpenShiftClusterD
return c.c.List(&cosmosdb.Options{Continuation: continuation})
}
func (c *openShiftClusters) ListAll(ctx context.Context) (*api.OpenShiftClusterDocuments, error) {
return c.c.ListAll(ctx, nil)
}
func (c *openShiftClusters) ListByPrefix(subscriptionID, prefix, continuation string) (cosmosdb.OpenShiftClusterDocumentIterator, error) {
if prefix != strings.ToLower(prefix) {
return nil, fmt.Errorf("prefix %q is not lower case", prefix)

75
pkg/database/portal.go Normal file
Просмотреть файл

@ -0,0 +1,75 @@
package database
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/Azure/ARO-RP/pkg/api"
"github.com/Azure/ARO-RP/pkg/database/cosmosdb"
"github.com/Azure/ARO-RP/pkg/util/deployment"
)
type portals struct {
c cosmosdb.PortalDocumentClient
}
// Portal is the database interface for PortalDocuments
type Portal interface {
Create(context.Context, *api.PortalDocument) (*api.PortalDocument, error)
Get(context.Context, string) (*api.PortalDocument, error)
Delete(context.Context, *api.PortalDocument) error
}
// NewPortal returns a new Portal
func NewPortal(ctx context.Context, deploymentMode deployment.Mode, dbc cosmosdb.DatabaseClient) (Portal, error) {
dbid, err := databaseName(deploymentMode)
if err != nil {
return nil, err
}
collc := cosmosdb.NewCollectionClient(dbc, dbid)
documentClient := cosmosdb.NewPortalDocumentClient(collc, collPortal)
return NewPortalWithProvidedClient(documentClient), nil
}
func NewPortalWithProvidedClient(client cosmosdb.PortalDocumentClient) Portal {
return &portals{
c: client,
}
}
func (c *portals) Create(ctx context.Context, doc *api.PortalDocument) (*api.PortalDocument, error) {
if doc.ID != strings.ToLower(doc.ID) {
return nil, fmt.Errorf("id %q is not lower case", doc.ID)
}
doc, err := c.c.Create(ctx, doc.ID, doc, nil)
if err, ok := err.(*cosmosdb.Error); ok && err.StatusCode == http.StatusConflict {
err.StatusCode = http.StatusPreconditionFailed
}
return doc, err
}
func (c *portals) Get(ctx context.Context, id string) (*api.PortalDocument, error) {
if id != strings.ToLower(id) {
return nil, fmt.Errorf("id %q is not lower case", id)
}
return c.c.Get(ctx, id, id, nil)
}
func (c *portals) Delete(ctx context.Context, doc *api.PortalDocument) error {
if doc.ID != strings.ToLower(doc.ID) {
return fmt.Errorf("id %q is not lower case", doc.ID)
}
return c.c.Delete(ctx, doc.ID, doc, nil)
}

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -7,6 +7,7 @@ import (
"fmt"
"io/ioutil"
"reflect"
"strings"
"github.com/ghodss/yaml"
)
@ -36,6 +37,7 @@ type Configuration struct {
DatabaseAccountName *string `json:"databaseAccountName,omitempty" value:"required"`
ExtraClusterKeyvaultAccessPolicies []interface{} `json:"extraClusterKeyvaultAccessPolicies,omitempty" value:"required"`
ExtraCosmosDBIPs []string `json:"extraCosmosDBIPs,omitempty" value:"required"`
ExtraPortalKeyvaultAccessPolicies []interface{} `json:"extraPortalKeyvaultAccessPolicies,omitempty" value:"required"`
ExtraServiceKeyvaultAccessPolicies []interface{} `json:"extraServiceKeyvaultAccessPolicies,omitempty" value:"required"`
FPServerCertCommonName *string `json:"fpServerCertCommonName,omitempty"`
FPServicePrincipalID *string `json:"fpServicePrincipalId,omitempty" value:"required"`
@ -46,6 +48,9 @@ type Configuration struct {
MDMFrontendURL *string `json:"mdmFrontendUrl,omitempty" value:"required"`
MDSDConfigVersion *string `json:"mdsdConfigVersion,omitempty" value:"required"`
MDSDEnvironment *string `json:"mdsdEnvironment,omitempty" value:"required"`
PortalAccessGroupIDs *string `json:"portalAccessGroupIds,omitempty" value:"required"`
PortalClientID *string `json:"portalClientId,omitempty" value:"required"`
PortalElevatedGroupIDs *string `json:"portalElevatedGroupIds,omitempty" value:"required"`
RPImagePrefix *string `json:"rpImagePrefix,omitempty" value:"required"`
RPMode *string `json:"rpMode,omitempty"`
RPNSGSourceAddressPrefixes []string `json:"rpNsgSourceAddressPrefixes,omitempty" value:"required"`
@ -119,5 +124,5 @@ func (conf *RPConfig) validate() error {
return nil
}
return fmt.Errorf("Configuration has missing fields: %s", missingFields)
return fmt.Errorf("configuration has missing fields: %s", strings.Join(missingFields, ","))
}

Просмотреть файл

@ -5,7 +5,6 @@ package deploy
import (
"encoding/json"
"fmt"
"reflect"
"strings"
"testing"
@ -114,87 +113,25 @@ func TestConfigNilable(t *testing.T) {
}
func TestConfigRequiredValues(t *testing.T) {
AdminAPICABundle := "AdminAPICABundle"
ACRResourceID := "ACRResourceID"
ExtraCosmosDBIPs := "ExtraCosmosDBIPs"
MDMFrontendURL := "MDMFrontendURL"
AdminAPIClientCertCommonName := "AdminAPIClientCertCommonName"
ClusterParentDomainName := "ClusterParentDomainName"
DatabaseAccountName := "DatabaseAccountName"
FPServerCertCommonName := "FPServerCertCommonName"
FPServicePrincipalID := "FPServicePrincipalID"
GlobalResourceGroupName := "GlobalResourceGroupName"
GlobalResourceGroupLocation := "GlobalResourceGroupLocation"
GlobalSubscriptionID := "GlobalSubscriptionID"
KeyvaultPrefix := "KeyvaultPrefix"
MDSDConfigVersion := "MDSDConfigVersion"
MDSDEnvironment := "MDSDEnvironment"
RPImagePrefix := "RPImagePrefix"
RPMode := "RPMode"
RPParentDomainName := "RPParentDomainName"
RPVersionStorageAccountName := "RPVersionStorageAccountName"
SSHPublicKey := "SSHPublicKey"
SubscriptionResourceGroupName := "SubscriptionResourceGroupName"
SubscriptionResourceGroupLocation := "SubscriptionResourceGroupLocation"
StorageAccountDomain := "StorageAccountDomain"
VMSize := "VMSize"
for _, tt := range []struct {
name string
config RPConfig
expect error
wantErr string
}{
{
name: "valid config",
config: RPConfig{
Configuration: &Configuration{
ACRResourceID: &ACRResourceID,
AdminAPICABundle: &AdminAPICABundle,
ExtraCosmosDBIPs: []string{ExtraCosmosDBIPs},
MDMFrontendURL: &MDMFrontendURL,
ACRReplicaDisabled: to.BoolPtr(true),
AdminAPIClientCertCommonName: &AdminAPIClientCertCommonName,
ClusterParentDomainName: &ClusterParentDomainName,
DatabaseAccountName: &DatabaseAccountName,
ExtraClusterKeyvaultAccessPolicies: []interface{}{},
ExtraServiceKeyvaultAccessPolicies: []interface{}{},
FPServerCertCommonName: &FPServerCertCommonName,
FPServicePrincipalID: &FPServicePrincipalID,
GlobalResourceGroupName: &GlobalResourceGroupName,
GlobalResourceGroupLocation: &GlobalResourceGroupLocation,
GlobalSubscriptionID: &GlobalSubscriptionID,
KeyvaultPrefix: &KeyvaultPrefix,
MDSDConfigVersion: &MDSDConfigVersion,
MDSDEnvironment: &MDSDEnvironment,
RPImagePrefix: &RPImagePrefix,
RPMode: &RPMode,
RPNSGSourceAddressPrefixes: []string{},
RPParentDomainName: &RPParentDomainName,
RPVersionStorageAccountName: &RPVersionStorageAccountName,
SSHPublicKey: &SSHPublicKey,
SubscriptionResourceGroupName: &SubscriptionResourceGroupName,
SubscriptionResourceGroupLocation: &SubscriptionResourceGroupLocation,
StorageAccountDomain: &StorageAccountDomain,
VMSize: &VMSize,
},
},
expect: nil,
},
{
name: "invalid config",
config: RPConfig{
Configuration: &Configuration{
ACRResourceID: &ACRResourceID,
AdminAPICABundle: &AdminAPICABundle,
ExtraCosmosDBIPs: []string{ExtraCosmosDBIPs},
Configuration: &Configuration{},
},
},
expect: fmt.Errorf("Configuration has missing fields: %s", "[RPVersionStorageAccountName AdminAPIClientCertCommonName ClusterParentDomainName DatabaseAccountName ExtraClusterKeyvaultAccessPolicies ExtraServiceKeyvaultAccessPolicies FPServicePrincipalID GlobalResourceGroupName GlobalResourceGroupLocation GlobalSubscriptionID KeyvaultPrefix MDMFrontendURL MDSDConfigVersion MDSDEnvironment RPImagePrefix RPNSGSourceAddressPrefixes RPParentDomainName SubscriptionResourceGroupName SubscriptionResourceGroupLocation SSHPublicKey StorageAccountDomain VMSize]"),
wantErr: "configuration has missing fields: ACRResourceID,RPVersionStorageAccountName,AdminAPICABundle,AdminAPIClientCertCommonName,ClusterParentDomainName,DatabaseAccountName,ExtraClusterKeyvaultAccessPolicies,ExtraCosmosDBIPs,ExtraPortalKeyvaultAccessPolicies,ExtraServiceKeyvaultAccessPolicies,FPServicePrincipalID,GlobalResourceGroupName,GlobalResourceGroupLocation,GlobalSubscriptionID,KeyvaultPrefix,MDMFrontendURL,MDSDConfigVersion,MDSDEnvironment,PortalAccessGroupIDs,PortalClientID,PortalElevatedGroupIDs,RPImagePrefix,RPNSGSourceAddressPrefixes,RPParentDomainName,SubscriptionResourceGroupName,SubscriptionResourceGroupLocation,SSHPublicKey,StorageAccountDomain,VMSize",
},
} {
valid := tt.config.validate()
if valid != tt.expect && valid.Error() != tt.expect.Error() {
t.Errorf("Expected %s but got %s", tt.name, valid.Error())
}
t.Run(tt.name, func(t *testing.T) {
err := tt.config.validate()
if err != nil && err.Error() != tt.wantErr ||
err == nil && tt.wantErr != "" {
t.Error(err)
}
})
}
}

Просмотреть файл

@ -180,6 +180,11 @@ func (d *deployer) configureDNS(ctx context.Context) error {
return err
}
portalPip, err := d.publicipaddresses.Get(ctx, d.config.ResourceGroupName, "portal-pip", "")
if err != nil {
return err
}
zone, err := d.zones.Get(ctx, d.config.ResourceGroupName, d.config.Location+"."+*d.config.Configuration.ClusterParentDomainName)
if err != nil {
return err
@ -199,6 +204,20 @@ func (d *deployer) configureDNS(ctx context.Context) error {
return err
}
_, err = d.globalrecordsets.CreateOrUpdate(ctx, *d.config.Configuration.GlobalResourceGroupName, *d.config.Configuration.RPParentDomainName, "admin."+d.config.Location, mgmtdns.A, mgmtdns.RecordSet{
RecordSetProperties: &mgmtdns.RecordSetProperties{
TTL: to.Int64Ptr(3600),
ARecords: &[]mgmtdns.ARecord{
{
Ipv4Address: portalPip.IPAddress,
},
},
},
}, "", "")
if err != nil {
return err
}
nsRecords := make([]mgmtdns.NsRecord, 0, len(*zone.NameServers))
for i := range *zone.NameServers {
nsRecords = append(nsRecords, mgmtdns.NsRecord{

Просмотреть файл

@ -6,6 +6,7 @@ package generator
// Deployment constants
const (
ClustersKeyvaultSuffix = "-cls"
PortalKeyvaultSuffix = "-por"
ServiceKeyvaultSuffix = "-svc"
)

Просмотреть файл

@ -507,7 +507,7 @@ func (g *generator) pevnet() *arm.Resource {
}
}
func (g *generator) pip() *arm.Resource {
func (g *generator) pip(name string) *arm.Resource {
return &arm.Resource{
Resource: &mgmtnetwork.PublicIPAddress{
Sku: &mgmtnetwork.PublicIPAddressSku{
@ -516,7 +516,7 @@ func (g *generator) pip() *arm.Resource {
PublicIPAddressPropertiesFormat: &mgmtnetwork.PublicIPAddressPropertiesFormat{
PublicIPAllocationMethod: mgmtnetwork.Static,
},
Name: to.StringPtr("rp-pip"),
Name: &name,
Type: to.StringPtr("Microsoft.Network/publicIPAddresses"),
Location: to.StringPtr("[resourceGroup().location]"),
},
@ -541,6 +541,14 @@ func (g *generator) lb() *arm.Resource {
},
Name: to.StringPtr("rp-frontend"),
},
{
FrontendIPConfigurationPropertiesFormat: &mgmtnetwork.FrontendIPConfigurationPropertiesFormat{
PublicIPAddress: &mgmtnetwork.PublicIPAddress{
ID: to.StringPtr("[resourceId('Microsoft.Network/publicIPAddresses', 'portal-pip')]"),
},
},
Name: to.StringPtr("portal-frontend"),
},
},
BackendAddressPools: &[]mgmtnetwork.BackendAddressPool{
{
@ -566,6 +574,42 @@ func (g *generator) lb() *arm.Resource {
},
Name: to.StringPtr("rp-lbrule"),
},
{
LoadBalancingRulePropertiesFormat: &mgmtnetwork.LoadBalancingRulePropertiesFormat{
FrontendIPConfiguration: &mgmtnetwork.SubResource{
ID: to.StringPtr("[resourceId('Microsoft.Network/loadBalancers/frontendIPConfigurations', 'rp-lb', 'portal-frontend')]"),
},
BackendAddressPool: &mgmtnetwork.SubResource{
ID: to.StringPtr("[resourceId('Microsoft.Network/loadBalancers/backendAddressPools', 'rp-lb', 'rp-backend')]"),
},
Probe: &mgmtnetwork.SubResource{
ID: to.StringPtr("[resourceId('Microsoft.Network/loadBalancers/probes', 'rp-lb', 'portal-probe-https')]"),
},
Protocol: mgmtnetwork.TransportProtocolTCP,
LoadDistribution: mgmtnetwork.LoadDistributionDefault,
FrontendPort: to.Int32Ptr(443),
BackendPort: to.Int32Ptr(444),
},
Name: to.StringPtr("portal-lbrule"),
},
{
LoadBalancingRulePropertiesFormat: &mgmtnetwork.LoadBalancingRulePropertiesFormat{
FrontendIPConfiguration: &mgmtnetwork.SubResource{
ID: to.StringPtr("[resourceId('Microsoft.Network/loadBalancers/frontendIPConfigurations', 'rp-lb', 'portal-frontend')]"),
},
BackendAddressPool: &mgmtnetwork.SubResource{
ID: to.StringPtr("[resourceId('Microsoft.Network/loadBalancers/backendAddressPools', 'rp-lb', 'rp-backend')]"),
},
Probe: &mgmtnetwork.SubResource{
ID: to.StringPtr("[resourceId('Microsoft.Network/loadBalancers/probes', 'rp-lb', 'portal-probe-ssh')]"),
},
Protocol: mgmtnetwork.TransportProtocolTCP,
LoadDistribution: mgmtnetwork.LoadDistributionDefault,
FrontendPort: to.Int32Ptr(22),
BackendPort: to.Int32Ptr(2222),
},
Name: to.StringPtr("portal-lbrule"),
},
},
Probes: &[]mgmtnetwork.Probe{
{
@ -577,6 +621,23 @@ func (g *generator) lb() *arm.Resource {
},
Name: to.StringPtr("rp-probe"),
},
{
ProbePropertiesFormat: &mgmtnetwork.ProbePropertiesFormat{
Protocol: mgmtnetwork.ProbeProtocolHTTPS,
Port: to.Int32Ptr(444),
NumberOfProbes: to.Int32Ptr(2),
RequestPath: to.StringPtr("/healthz/ready"),
},
Name: to.StringPtr("portal-probe-https"),
},
{
ProbePropertiesFormat: &mgmtnetwork.ProbePropertiesFormat{
Protocol: mgmtnetwork.ProbeProtocolTCP,
Port: to.Int32Ptr(2222),
NumberOfProbes: to.Int32Ptr(2),
},
Name: to.StringPtr("portal-probe-ssh"),
},
},
},
Name: to.StringPtr("rp-lb"),
@ -664,6 +725,9 @@ func (g *generator) vmss() *arm.Resource {
"mdsdEnvironment",
"acrResourceId",
"domainName",
"portalAccessGroupIds",
"portalClientId",
"portalElevatedGroupIds",
"rpImage",
"rpMode",
"adminApiClientCertCommonName",
@ -764,6 +828,8 @@ EOF
sysctl --system
firewall-cmd --add-port=443/tcp --permanent
firewall-cmd --add-port=444/tcp --permanent
firewall-cmd --add-port=2222/tcp --permanent
cat >/etc/td-agent-bit/td-agent-bit.conf <<'EOF'
[INPUT]
@ -970,9 +1036,49 @@ StartLimitInterval=0
WantedBy=multi-user.target
EOF
cat >/etc/sysconfig/aro-portal <<EOF
MDM_ACCOUNT=AzureRedHatOpenShiftRP
MDM_NAMESPACE=Portal
AZURE_PORTAL_CLIENT_ID='$PORTALCLIENTID'
AZURE_PORTAL_ACCESS_GROUP_IDS='$PORTALACCESSGROUPIDS'
AZURE_PORTAL_ELEVATED_GROUP_IDS='$PORTALELEVATEDGROUPIDS'
RPIMAGE='$RPIMAGE'
RP_MODE='$RPMODE'
EOF
cat >/etc/systemd/system/aro-portal.service <<'EOF'
[Unit]
After=docker.service
Requires=docker.service
StartLimitInterval=0
[Service]
EnvironmentFile=/etc/sysconfig/aro-portal
ExecStartPre=-/usr/bin/docker rm -f %N
ExecStart=/usr/bin/docker run \
--hostname %H \
--name %N \
--rm \
-e ADMIN_API_CLIENT_CERT_COMMON_NAME \
-e MDM_ACCOUNT \
-e MDM_NAMESPACE \
-e RP_MODE \
-p 444:8444 \
-p 2222:2222 \
-v /run/systemd/journal:/run/systemd/journal \
-v /var/etw:/var/etw:z \
$RPIMAGE \
portal
Restart=always
RestartSec=1
[Install]
WantedBy=multi-user.target
EOF
chcon -R system_u:object_r:var_log_t:s0 /var/opt/microsoft/linuxmonagent
for service in aro-monitor aro-rp auoms azsecd azsecmond mdsd mdm chronyd td-agent-bit; do
for service in aro-monitor aro-portal aro-rp auoms azsecd azsecmond mdsd mdm chronyd td-agent-bit; do
systemctl enable $service.service
done
@ -1155,6 +1261,20 @@ func (g *generator) clusterKeyvaultAccessPolicies() []mgmtkeyvault.AccessPolicyE
}
}
func (g *generator) portalKeyvaultAccessPolicies() []mgmtkeyvault.AccessPolicyEntry {
return []mgmtkeyvault.AccessPolicyEntry{
{
TenantID: &tenantUUIDHack,
ObjectID: to.StringPtr("[parameters('rpServicePrincipalId')]"),
Permissions: &mgmtkeyvault.Permissions{
Secrets: &[]mgmtkeyvault.SecretPermissions{
mgmtkeyvault.SecretPermissionsGet,
},
},
},
}
}
func (g *generator) serviceKeyvaultAccessPolicies() []mgmtkeyvault.AccessPolicyEntry {
return []mgmtkeyvault.AccessPolicyEntry{
{
@ -1207,6 +1327,50 @@ func (g *generator) clustersKeyvault() *arm.Resource {
}
}
func (g *generator) portalKeyvault() *arm.Resource {
vault := &mgmtkeyvault.Vault{
Properties: &mgmtkeyvault.VaultProperties{
EnableSoftDelete: to.BoolPtr(true),
TenantID: &tenantUUIDHack,
Sku: &mgmtkeyvault.Sku{
Name: mgmtkeyvault.Standard,
Family: to.StringPtr("A"),
},
AccessPolicies: &[]mgmtkeyvault.AccessPolicyEntry{},
},
Name: to.StringPtr("[concat(parameters('keyvaultPrefix'), '" + PortalKeyvaultSuffix + "')]"),
Type: to.StringPtr("Microsoft.KeyVault/vaults"),
Location: to.StringPtr("[resourceGroup().location]"),
}
if !g.production {
*vault.Properties.AccessPolicies = append(g.portalKeyvaultAccessPolicies(),
mgmtkeyvault.AccessPolicyEntry{
TenantID: &tenantUUIDHack,
ObjectID: to.StringPtr("[parameters('adminObjectId')]"),
Permissions: &mgmtkeyvault.Permissions{
Certificates: &[]mgmtkeyvault.CertificatePermissions{
mgmtkeyvault.Delete,
mgmtkeyvault.Get,
mgmtkeyvault.Import,
mgmtkeyvault.List,
},
Secrets: &[]mgmtkeyvault.SecretPermissions{
mgmtkeyvault.SecretPermissionsSet,
mgmtkeyvault.SecretPermissionsList,
},
},
},
)
}
return &arm.Resource{
Resource: vault,
Condition: g.conditionStanza("fullDeploy"),
APIVersion: azureclient.APIVersion("Microsoft.KeyVault"),
}
}
func (g *generator) serviceKeyvault() *arm.Resource {
vault := &mgmtkeyvault.Vault{
Properties: &mgmtkeyvault.VaultProperties{
@ -1304,6 +1468,36 @@ func (g *generator) cosmosdb() []*arm.Resource {
}
func (g *generator) database(databaseName string, addDependsOn bool) []*arm.Resource {
portal := &arm.Resource{
Resource: &mgmtdocumentdb.SQLContainerCreateUpdateParameters{
SQLContainerCreateUpdateProperties: &mgmtdocumentdb.SQLContainerCreateUpdateProperties{
Resource: &mgmtdocumentdb.SQLContainerResource{
ID: to.StringPtr("Portal"),
PartitionKey: &mgmtdocumentdb.ContainerPartitionKey{
Paths: &[]string{
"/id",
},
Kind: mgmtdocumentdb.PartitionKindHash,
},
DefaultTTL: to.Int32Ptr(-1),
},
Options: map[string]*string{},
},
Name: to.StringPtr("[concat(parameters('databaseAccountName'), '/', " + databaseName + ", '/Portal')]"),
Type: to.StringPtr("Microsoft.DocumentDB/databaseAccounts/sqlDatabases/containers"),
Location: to.StringPtr("[resourceGroup().location]"),
},
Condition: g.conditionStanza("fullDeploy"),
APIVersion: azureclient.APIVersion("Microsoft.DocumentDB"),
DependsOn: []string{
"[resourceId('Microsoft.DocumentDB/databaseAccounts/sqlDatabases', parameters('databaseAccountName'), " + databaseName + ")]",
},
}
if g.production {
portal.Resource.(*mgmtdocumentdb.SQLContainerCreateUpdateParameters).SQLContainerCreateUpdateProperties.Options["throughput"] = to.StringPtr("400")
}
rs := []*arm.Resource{
{
Resource: &mgmtdocumentdb.SQLDatabaseCreateUpdateParameters{
@ -1439,6 +1633,7 @@ func (g *generator) database(databaseName string, addDependsOn bool) []*arm.Reso
"[resourceId('Microsoft.DocumentDB/databaseAccounts/sqlDatabases', parameters('databaseAccountName'), " + databaseName + ")]",
},
},
portal,
{
Resource: &mgmtdocumentdb.SQLContainerCreateUpdateParameters{
SQLContainerCreateUpdateProperties: &mgmtdocumentdb.SQLContainerCreateUpdateProperties{

Просмотреть файл

@ -86,6 +86,9 @@ func (g *generator) rpTemplate() *arm.Template {
"mdmFrontendUrl",
"mdsdConfigVersion",
"mdsdEnvironment",
"portalAccessGroupIds",
"portalClientId",
"portalElevatedGroupIds",
"rpImage",
"rpMode",
"sshPublicKey",
@ -111,7 +114,11 @@ func (g *generator) rpTemplate() *arm.Template {
}
if g.production {
t.Resources = append(t.Resources, g.pip(), g.lb(), g.vmss(),
t.Resources = append(t.Resources,
g.pip("rp-pip"),
g.pip("portal-pip"),
g.lb(),
g.vmss(),
g.storageAccount(),
g.lbAlert(30.0, 2, "rp-availability-alert", "PT5M", "PT15M", "DipAvailability"), // triggers on all 3 RPs being down for 10min, can't be >=0.3 due to deploys going down to 32% at times.
g.lbAlert(67.0, 3, "rp-degraded-alert", "PT15M", "PT6H", "DipAvailability"), // 1/3 backend down for 1h or 2/3 down for 3h in the last 6h
@ -270,6 +277,7 @@ func (g *generator) preDeployTemplate() *arm.Template {
if g.production {
t.Variables = map[string]interface{}{
"clusterKeyvaultAccessPolicies": g.clusterKeyvaultAccessPolicies(),
"portalKeyvaultAccessPolicies": g.portalKeyvaultAccessPolicies(),
"serviceKeyvaultAccessPolicies": g.serviceKeyvaultAccessPolicies(),
}
}
@ -284,6 +292,7 @@ func (g *generator) preDeployTemplate() *arm.Template {
params = append(params,
"deployNSGs",
"extraClusterKeyvaultAccessPolicies",
"extraPortalKeyvaultAccessPolicies",
"extraServiceKeyvaultAccessPolicies",
"fullDeploy",
"rpNsgSourceAddressPrefixes",
@ -300,7 +309,9 @@ func (g *generator) preDeployTemplate() *arm.Template {
case "deployNSGs":
p.Type = "bool"
p.DefaultValue = false
case "extraClusterKeyvaultAccessPolicies", "extraServiceKeyvaultAccessPolicies":
case "extraClusterKeyvaultAccessPolicies",
"extraPortalKeyvaultAccessPolicies",
"extraServiceKeyvaultAccessPolicies":
p.Type = "array"
p.DefaultValue = []interface{}{}
case "fullDeploy":
@ -318,9 +329,10 @@ func (g *generator) preDeployTemplate() *arm.Template {
t.Resources = append(t.Resources,
g.securityGroupRP(),
g.securityGroupPE(),
// clustersKeyvault must preceed serviceKeyvault due to terrible
// bytes.Replace in templateFixup
// clustersKeyvault, portalKeyvault and serviceKeyvault must be in this
// order due to terrible bytes.Replace in templateFixup
g.clustersKeyvault(),
g.portalKeyvault(),
g.serviceKeyvault(),
)
@ -409,6 +421,7 @@ func (g *generator) templateFixup(t *arm.Template) ([]byte, error) {
b = bytes.ReplaceAll(b, []byte(`"capacity": 1337`), []byte(`"capacity": "[int(parameters('ciCapacity'))]"`))
if g.production {
b = bytes.Replace(b, []byte(`"accessPolicies": []`), []byte(`"accessPolicies": "[concat(variables('clusterKeyvaultAccessPolicies'), parameters('extraClusterKeyvaultAccessPolicies'))]"`), 1)
b = bytes.Replace(b, []byte(`"accessPolicies": []`), []byte(`"accessPolicies": "[concat(variables('portalKeyvaultAccessPolicies'), parameters('extraPortalKeyvaultAccessPolicies'))]"`), 1)
b = bytes.Replace(b, []byte(`"accessPolicies": []`), []byte(`"accessPolicies": "[concat(variables('serviceKeyvaultAccessPolicies'), parameters('extraServiceKeyvaultAccessPolicies'))]"`), 1)
b = bytes.Replace(b, []byte(`"sourceAddressPrefixes": []`), []byte(`"sourceAddressPrefixes": "[parameters('rpNsgSourceAddressPrefixes')]"`), 1)
}

Просмотреть файл

@ -6,6 +6,8 @@ package deploy
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json"
"path/filepath"
@ -307,7 +309,17 @@ func (d *deployer) configureServiceSecrets(ctx context.Context) error {
return err
}
return d.ensureSecret(ctx, secrets, env.FrontendEncryptionSecretName)
err = d.ensureSecret(ctx, secrets, env.FrontendEncryptionSecretName)
if err != nil {
return err
}
err = d.ensureSecret(ctx, secrets, env.PortalServerSessionKeySecretName)
if err != nil {
return err
}
return d.ensureSecretKey(ctx, secrets, env.PortalServerSSHKeySecretName)
}
func (d *deployer) ensureSecret(ctx context.Context, existingSecrets []keyvault.SecretItem, secretName string) error {
@ -328,3 +340,21 @@ func (d *deployer) ensureSecret(ctx context.Context, existingSecrets []keyvault.
Value: to.StringPtr(base64.StdEncoding.EncodeToString(key)),
})
}
func (d *deployer) ensureSecretKey(ctx context.Context, existingSecrets []keyvault.SecretItem, secretName string) error {
for _, secret := range existingSecrets {
if filepath.Base(*secret.ID) == secretName {
return nil
}
}
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
}
d.log.Infof("setting %s", secretName)
return d.keyvault.SetSecret(ctx, secretName, keyvault.SecretSetParameters{
Value: to.StringPtr(base64.StdEncoding.EncodeToString(x509.MarshalPKCS1PrivateKey(key))),
})
}

4
pkg/env/env.go поставляемый
Просмотреть файл

@ -26,6 +26,10 @@ const (
FrontendEncryptionSecretName = "fe-encryption-key"
RPLoggingSecretName = "rp-mdsd"
RPMonitoringSecretName = "rp-mdm"
PortalServerSecretName = "portal-server"
PortalServerClientSecretName = "portal-client"
PortalServerSessionKeySecretName = "portal-session-key"
PortalServerSSHKeySecretName = "portal-sshkey"
)
type Interface interface {

2
pkg/env/generate.go поставляемый
Просмотреть файл

@ -4,5 +4,5 @@ package env
// Licensed under the Apache License 2.0.
//go:generate rm -rf ../util/mocks/$GOPACKAGE
//go:generate go run ../../vendor/github.com/golang/mock/mockgen -destination=../util/mocks/$GOPACKAGE/$GOPACKAGE.go github.com/Azure/ARO-RP/pkg/$GOPACKAGE Interface
//go:generate go run ../../vendor/github.com/golang/mock/mockgen -destination=../util/mocks/$GOPACKAGE/$GOPACKAGE.go github.com/Azure/ARO-RP/pkg/$GOPACKAGE Core,Interface
//go:generate go run ../../vendor/golang.org/x/tools/cmd/goimports -local=github.com/Azure/ARO-RP -e -w ../util/mocks/$GOPACKAGE/$GOPACKAGE.go

Просмотреть файл

@ -0,0 +1,93 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<link rel="stylesheet" href="lib/bootstrap-4.5.2.min.css">
<title>ARO SRE portal ({{ .location }})</title>
</head>
<body>
<div class="navbar navbar-light bg-light shadow-sm">
<div class="navbar-brand">
<strong>ARO SRE portal ({{ .location }})</strong>
</div>
<button class="btn btn-secondary" id="btnLogout">Logout</button>
</div>
<div class="container py-4">
<div class="form-group">
<label for="selResourceId">Cluster:</label>
<div class="col-sm-10">
<select class="form-control form-control-sm" id="selResourceId">
</select>
</div>
</div>
<div class="form-group">
<label for="selMaster">Master:</label>
<div class="col-sm-10">
<select class="form-control form-control-sm" id="selMaster">
<option value="0">master-0</option>
<option value="1">master-1</option>
<option value="2">master-2</option>
</select>
</div>
</div>
<button class="btn btn-secondary" id="btnPrometheus">Prometheus</button>
<button class="btn btn-secondary" id="btnKubeconfig">Kubeconfig</button>
<button class="btn btn-secondary" id="btnSSH">SSH</button>
<div class="py-4" id="divAlerts"></div>
</div>
<template id="tmplSSHAlert">
<div class="alert alert-primary alert-dismissible fade show" role="alert">
<div>
<button class="btn btn-secondary copy-button">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="white" width="18px" height="18px">
<path d="M0 0h24v24H0z" fill="none"/>
<path d="M16 1H4c-1.1 0-2 .9-2 2v14h2V3h12V1zm3 4H8c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h11c1.1 0 2-.9 2-2V7c0-1.1-.9-2-2-2zm0 16H8V7h11v14z"/>
</svg>
</button>
<span data-copy="command">Command: <code></code></span>
</div>
<div class="py-2">
<button class="btn btn-secondary copy-button">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="white" width="18px" height="18px">
<path d="M0 0h24v24H0z" fill="none"/>
<path d="M16 1H4c-1.1 0-2 .9-2 2v14h2V3h12V1zm3 4H8c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h11c1.1 0 2-.9 2-2V7c0-1.1-.9-2-2-2zm0 16H8V7h11v14z"/>
</svg>
</button>
<span data-copy="password">Password: <code></code></span>
</div>
<button type="button" class="close" data-dismiss="alert" aria-label="Close">
<span aria-hidden="true">&times;</span>
</button>
</div>
</template>
<template id="tmplSSHAlertError">
<div class="alert alert-danger alert-dismissible fade show" role="alert">
<span data-copy="error"></span>
<button type="button" class="close" data-dismiss="alert" aria-label="Close">
<span aria-hidden="true">&times;</span>
</button>
</div>
</template>
{{ .csrfField }}
<script src="lib/jquery-3.5.1.min.js"></script>
<script src="lib/bootstrap-4.5.2.min.js"></script>
<script src="index.js"></script>
</body>
</html>

Просмотреть файл

@ -0,0 +1,89 @@
$.extend({
redirect: function (location, args) {
var form = $("<form method='POST' style='display: none;'></form>");
form.attr("action", location);
$.each(args || {}, function (key, value) {
var input = $("<input name='hidden'></input>");
input.attr("name", key);
input.attr("value", value);
form.append(input);
});
form.append($("input[name='gorilla.csrf.Token']").first());
form.appendTo("body").submit();
}
});
$(document).ready(function () {
$.ajax({
url: "/api/clusters",
success: function (clusters) {
$.each(clusters, function (i, cluster) {
$("#selResourceId").append($("<option>").text(cluster));
});
},
dataType: "json",
});
$("#btnLogout").click(function () {
$.redirect("/api/logout");
});
$("#btnKubeconfig").click(function () {
$.redirect($("#selResourceId").val() + "/kubeconfig/new");
});
$("#btnPrometheus").click(function () {
window.location = $("#selResourceId").val() + "/prometheus";
});
$("#btnSSH").click(function () {
$.ajax({
method: "POST",
url: $("#selResourceId").val() + "/ssh/new",
headers: {
"X-CSRF-Token": $("input[name='gorilla.csrf.Token']").val(),
},
contentType: "application/json",
data: JSON.stringify({
"master": parseInt($("#selMaster").val()),
}),
success: function (reply) {
if (reply["error"]) {
var template = $("#tmplSSHAlertError").html();
var alert = $(template);
alert.find("span[data-copy='error']").text(reply["error"]);
$("#divAlerts").html(alert);
return;
}
var template = $("#tmplSSHAlert").html();
var alert = $(template);
alert.find("span[data-copy='command'] > code").text(reply["command"]);
alert.find("span[data-copy='command']").attr("data-copy", reply["command"]);
alert.find("span[data-copy='password'] > code").text("********");
alert.find("span[data-copy='password']").attr("data-copy", reply["password"]);
$("#divAlerts").html(alert);
$('.copy-button').click(function () {
var textarea = $("<textarea class='style: hidden;' id='textarea'></textarea>");
textarea.text($(this).next().attr("data-copy"));
textarea.appendTo("body");
textarea = document.getElementById("textarea")
textarea.select();
textarea.setSelectionRange(0, textarea.value.length + 1);
document.execCommand('copy');
document.body.removeChild(textarea)
});
},
dataType: "json",
});
});
});

7
pkg/portal/assets/lib/bootstrap-4.5.2.min.css поставляемый Normal file

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

7
pkg/portal/assets/lib/bootstrap-4.5.2.min.js поставляемый Normal file

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

2
pkg/portal/assets/lib/jquery-3.5.1.min.js поставляемый Normal file

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

338
pkg/portal/bindata.go Normal file

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

7
pkg/portal/generate.go Normal file
Просмотреть файл

@ -0,0 +1,7 @@
package portal
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
//go:generate go run ../../vendor/github.com/go-bindata/go-bindata/go-bindata -nometadata -pkg $GOPACKAGE -prefix assets assets/...
//go:generate gofmt -s -l -w bindata.go

Просмотреть файл

@ -0,0 +1,183 @@
package kubeconfig
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"log"
"net/http"
"net/http/httputil"
"strings"
"time"
"github.com/gorilla/mux"
uuid "github.com/satori/go.uuid"
"github.com/sirupsen/logrus"
v1 "k8s.io/client-go/tools/clientcmd/api/v1"
"github.com/Azure/ARO-RP/pkg/api"
"github.com/Azure/ARO-RP/pkg/api/validate"
"github.com/Azure/ARO-RP/pkg/database"
"github.com/Azure/ARO-RP/pkg/portal/middleware"
"github.com/Azure/ARO-RP/pkg/portal/util/clientcache"
"github.com/Azure/ARO-RP/pkg/proxy"
"github.com/Azure/ARO-RP/pkg/util/roundtripper"
)
const (
kubeconfigNewTimeout = 6 * time.Hour
)
type kubeconfig struct {
log *logrus.Entry
baseAccessLog *logrus.Entry
servingCert *x509.Certificate
elevatedGroupIDs []string
dbOpenShiftClusters database.OpenShiftClusters
dbPortal database.Portal
dialer proxy.Dialer
clientCache clientcache.ClientCache
newToken func() string
}
func New(baseLog *logrus.Entry,
baseAccessLog *logrus.Entry,
servingCert *x509.Certificate,
elevatedGroupIDs []string,
dbOpenShiftClusters database.OpenShiftClusters,
dbPortal database.Portal,
dialer proxy.Dialer,
aadAuthenticatedRouter,
unauthenticatedRouter *mux.Router) *kubeconfig {
k := &kubeconfig{
log: baseLog,
baseAccessLog: baseAccessLog,
servingCert: servingCert,
elevatedGroupIDs: elevatedGroupIDs,
dbOpenShiftClusters: dbOpenShiftClusters,
dbPortal: dbPortal,
dialer: dialer,
clientCache: clientcache.New(time.Hour),
newToken: func() string { return uuid.NewV4().String() },
}
rp := &httputil.ReverseProxy{
Director: k.director,
Transport: roundtripper.RoundTripperFunc(k.roundTripper),
ErrorLog: log.New(k.log.Writer(), "", 0),
}
aadAuthenticatedRouter.NewRoute().Methods(http.MethodPost).Path("/subscriptions/{subscriptionId}/resourcegroups/{resourceGroupName}/providers/microsoft.redhatopenshift/openshiftclusters/{resourceName}/kubeconfig/new").HandlerFunc(k.new)
bearerAuthenticatedRouter := unauthenticatedRouter.NewRoute().Subrouter()
bearerAuthenticatedRouter.Use(middleware.Bearer(k.dbPortal))
bearerAuthenticatedRouter.Use(middleware.Log(k.baseAccessLog))
bearerAuthenticatedRouter.PathPrefix("/subscriptions/{subscriptionId}/resourcegroups/{resourceGroupName}/providers/microsoft.redhatopenshift/openshiftclusters/{resourceName}/kubeconfig/proxy/").Handler(rp)
return k
}
// new creates a new PortalDocument allowing kubeconfig access to a cluster for
// 6 hours and returns a kubeconfig with the temporary credentials
func (k *kubeconfig) new(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
resourceID := strings.Join(strings.Split(r.URL.Path, "/")[:9], "/")
if !validate.RxClusterID.MatchString(resourceID) {
http.Error(w, fmt.Sprintf("invalid resourceId %q", resourceID), http.StatusBadRequest)
return
}
elevated := middleware.GroupsIntersect(k.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))
token := k.newToken()
portalDoc := &api.PortalDocument{
ID: token,
TTL: int(kubeconfigNewTimeout / time.Second),
Portal: &api.Portal{
Username: ctx.Value(middleware.ContextKeyUsername).(string),
ID: resourceID,
Kubeconfig: &api.Kubeconfig{
Elevated: elevated,
},
},
}
_, err := k.dbPortal.Create(ctx, portalDoc)
if err != nil {
k.internalServerError(w, err)
return
}
b, err := k.makeKubeconfig("https://"+r.Host+resourceID+"/kubeconfig/proxy", token)
if err != nil {
k.internalServerError(w, err)
return
}
filename := strings.Split(r.URL.Path, "/")[8]
if elevated {
filename += "-elevated"
}
w.Header().Add("Content-Type", "application/json")
w.Header().Add("Content-Disposition", `attachment; filename="`+filename+`.kubeconfig"`)
_, _ = w.Write(b)
}
func (k *kubeconfig) internalServerError(w http.ResponseWriter, err error) {
k.log.Warn(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
func (k *kubeconfig) makeKubeconfig(server, token string) ([]byte, error) {
return json.MarshalIndent(&v1.Config{
APIVersion: "v1",
Kind: "Config",
Clusters: []v1.NamedCluster{
{
Name: "cluster",
Cluster: v1.Cluster{
Server: server,
CertificateAuthorityData: pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: k.servingCert.Raw,
}),
},
},
},
AuthInfos: []v1.NamedAuthInfo{
{
Name: "user",
AuthInfo: v1.AuthInfo{
Token: token,
},
},
},
Contexts: []v1.NamedContext{
{
Name: "context",
Context: v1.Context{
Cluster: "cluster",
Namespace: "default",
AuthInfo: "user",
},
},
},
CurrentContext: "context",
}, "", " ")
}

Просмотреть файл

@ -0,0 +1,189 @@
package kubeconfig
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"context"
"crypto/x509"
"fmt"
"io/ioutil"
"net/http"
"reflect"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
"github.com/Azure/ARO-RP/pkg/api"
"github.com/Azure/ARO-RP/pkg/database/cosmosdb"
"github.com/Azure/ARO-RP/pkg/portal/middleware"
"github.com/Azure/ARO-RP/pkg/portal/util/responsewriter"
testdatabase "github.com/Azure/ARO-RP/test/database"
)
func TestNew(t *testing.T) {
resourceID := "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster"
elevatedGroupIDs := []string{"10000000-0000-0000-0000-000000000000"}
username := "username"
password := "password"
servingCert := &x509.Certificate{}
for _, tt := range []struct {
name string
r func(*http.Request)
elevated bool
fixtureChecker func(*testdatabase.Fixture, *testdatabase.Checker, *cosmosdb.FakePortalDocumentClient)
wantStatusCode int
wantHeaders http.Header
wantBody string
}{
{
name: "success - not elevated",
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, portalClient *cosmosdb.FakePortalDocumentClient) {
portalDocument := &api.PortalDocument{
ID: password,
TTL: 21600,
Portal: &api.Portal{
Username: username,
ID: resourceID,
Kubeconfig: &api.Kubeconfig{},
},
}
checker.AddPortalDocuments(portalDocument)
},
wantStatusCode: http.StatusOK,
wantHeaders: http.Header{
"Content-Disposition": []string{`attachment; filename="cluster.kubeconfig"`},
},
wantBody: "{\n \"kind\": \"Config\",\n \"apiVersion\": \"v1\",\n \"preferences\": {},\n \"clusters\": [\n {\n \"name\": \"cluster\",\n \"cluster\": {\n \"server\": \"https://localhost:8444/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster/kubeconfig/proxy\",\n \"certificate-authority-data\": \"LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCi0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K\"\n }\n }\n ],\n \"users\": [\n {\n \"name\": \"user\",\n \"user\": {\n \"token\": \"password\"\n }\n }\n ],\n \"contexts\": [\n {\n \"name\": \"context\",\n \"context\": {\n \"cluster\": \"cluster\",\n \"user\": \"user\",\n \"namespace\": \"default\"\n }\n }\n ],\n \"current-context\": \"context\"\n}",
},
{
name: "success - elevated",
elevated: true,
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, portalClient *cosmosdb.FakePortalDocumentClient) {
portalDocument := &api.PortalDocument{
ID: password,
TTL: 21600,
Portal: &api.Portal{
Username: username,
ID: resourceID,
Kubeconfig: &api.Kubeconfig{
Elevated: true,
},
},
}
checker.AddPortalDocuments(portalDocument)
},
wantStatusCode: http.StatusOK,
wantHeaders: http.Header{
"Content-Disposition": []string{`attachment; filename="cluster-elevated.kubeconfig"`},
},
wantBody: "{\n \"kind\": \"Config\",\n \"apiVersion\": \"v1\",\n \"preferences\": {},\n \"clusters\": [\n {\n \"name\": \"cluster\",\n \"cluster\": {\n \"server\": \"https://localhost:8444/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster/kubeconfig/proxy\",\n \"certificate-authority-data\": \"LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCi0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K\"\n }\n }\n ],\n \"users\": [\n {\n \"name\": \"user\",\n \"user\": {\n \"token\": \"password\"\n }\n }\n ],\n \"contexts\": [\n {\n \"name\": \"context\",\n \"context\": {\n \"cluster\": \"cluster\",\n \"user\": \"user\",\n \"namespace\": \"default\"\n }\n }\n ],\n \"current-context\": \"context\"\n}",
},
{
name: "bad path",
r: func(r *http.Request) {
r.URL.Path = "/subscriptions/BAD/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster/kubeconfig/new"
},
wantStatusCode: http.StatusBadRequest,
wantBody: "invalid resourceId \"/subscriptions/BAD/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster\"\n",
},
{
name: "sad database",
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, portalClient *cosmosdb.FakePortalDocumentClient) {
portalClient.SetError(fmt.Errorf("sad"))
},
wantStatusCode: http.StatusInternalServerError,
wantBody: "Internal Server Error\n",
},
} {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
dbPortal, portalClient := testdatabase.NewFakePortal()
fixture := testdatabase.NewFixture().
WithPortal(dbPortal)
checker := testdatabase.NewChecker()
if tt.fixtureChecker != nil {
tt.fixtureChecker(fixture, checker, portalClient)
}
err := fixture.Create()
if err != nil {
t.Fatal(err)
}
ctx = context.WithValue(ctx, middleware.ContextKeyUsername, username)
if tt.elevated {
ctx = context.WithValue(ctx, middleware.ContextKeyGroups, elevatedGroupIDs)
} else {
ctx = context.WithValue(ctx, middleware.ContextKeyGroups, []string(nil))
}
r, err := http.NewRequestWithContext(ctx, http.MethodPost,
"https://localhost:8444"+resourceID+"/kubeconfig/new", nil)
if err != nil {
panic(err)
}
r.Header.Set("Content-Type", "application/json")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
aadAuthenticatedRouter := &mux.Router{}
k := New(logrus.NewEntry(logrus.StandardLogger()), nil, servingCert, elevatedGroupIDs, nil, dbPortal, nil, aadAuthenticatedRouter, &mux.Router{})
k.newToken = func() string { return password }
if tt.r != nil {
tt.r(r)
}
w := responsewriter.New(r)
aadAuthenticatedRouter.ServeHTTP(w, r)
portalClient.SetError(nil)
for _, err = range checker.CheckPortals(portalClient) {
t.Error(err)
}
resp := w.Response()
if resp.StatusCode != tt.wantStatusCode {
t.Error(resp.StatusCode)
}
for k, v := range tt.wantHeaders {
if !reflect.DeepEqual(resp.Header[k], v) {
t.Errorf(k, resp.Header[k], v)
}
}
wantContentType := "application/json"
if resp.StatusCode != http.StatusOK {
wantContentType = "text/plain; charset=utf-8"
}
if resp.Header.Get("Content-Type") != wantContentType {
t.Error(resp.Header.Get("Content-Type"))
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != tt.wantBody {
t.Errorf("%q", string(b))
}
})
}
}

Просмотреть файл

@ -0,0 +1,218 @@
package kubeconfig
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net/http"
"strings"
"time"
"github.com/ghodss/yaml"
v1 "k8s.io/client-go/tools/clientcmd/api/v1"
"github.com/Azure/ARO-RP/pkg/api"
"github.com/Azure/ARO-RP/pkg/api/validate"
"github.com/Azure/ARO-RP/pkg/portal/middleware"
"github.com/Azure/ARO-RP/pkg/portal/util/responsewriter"
"github.com/Azure/ARO-RP/pkg/util/pem"
"github.com/Azure/ARO-RP/pkg/util/restconfig"
)
const (
kubeconfigTimeout = time.Hour
)
type contextKey int
const (
contextKeyClient contextKey = iota
contextKeyResponse
)
// director is called by the ReverseProxy. It converts an incoming request into
// the one that'll go out to the API server. It also resolves an HTTP client
// that will be able to make the ongoing request.
//
// Unfortunately the signature of httputil.ReverseProxy.Director does not allow
// us to return values. We get around this limitation slightly naughtily by
// storing return information in the request context.
func (k *kubeconfig) director(r *http.Request) {
ctx := r.Context()
portalDoc, _ := ctx.Value(middleware.ContextKeyPortalDoc).(*api.PortalDocument)
if portalDoc == nil || portalDoc.Portal.Kubeconfig == nil {
k.error(r, http.StatusForbidden, nil)
return
}
resourceID := strings.Join(strings.Split(r.URL.Path, "/")[:9], "/")
if !validate.RxClusterID.MatchString(resourceID) ||
!strings.EqualFold(resourceID, portalDoc.Portal.ID) {
k.error(r, http.StatusBadRequest, nil)
return
}
key := struct {
resourceID string
elevated bool
}{
resourceID: portalDoc.Portal.ID,
elevated: portalDoc.Portal.Kubeconfig.Elevated,
}
cli := k.clientCache.Get(key)
if cli == nil {
var err error
cli, err = k.cli(ctx, key.resourceID, key.elevated)
if err != nil {
k.error(r, http.StatusInternalServerError, err)
return
}
k.clientCache.Put(key, cli)
}
r.RequestURI = ""
r.URL.Scheme = "https"
r.URL.Host = "kubernetes:6443"
r.URL.Path = "/" + strings.Join(strings.Split(r.URL.Path, "/")[11:], "/")
r.Header.Del("Authorization")
r.Host = r.URL.Host
// http.Request.WithContext returns a copy of the original Request with the
// new context, but we have no way to return it, so we overwrite our
// existing request.
*r = *r.WithContext(context.WithValue(ctx, contextKeyClient, cli))
}
// cli returns an appropriately configured HTTP client for forwarding the
// incoming request to a cluster
func (k *kubeconfig) cli(ctx context.Context, resourceID string, elevated bool) (*http.Client, error) {
openShiftDoc, err := k.dbOpenShiftClusters.Get(ctx, resourceID)
if err != nil {
return nil, err
}
kc := openShiftDoc.OpenShiftCluster.Properties.AROSREKubeconfig
if elevated {
kc = openShiftDoc.OpenShiftCluster.Properties.AROServiceKubeconfig
}
var kubeconfig *v1.Config
err = yaml.Unmarshal(kc, &kubeconfig)
if err != nil {
return nil, err
}
var b []byte
b = append(b, kubeconfig.AuthInfos[0].AuthInfo.ClientKeyData...)
b = append(b, kubeconfig.AuthInfos[0].AuthInfo.ClientCertificateData...)
clientKey, clientCerts, err := pem.Parse(b)
if err != nil {
return nil, err
}
_, caCerts, err := pem.Parse(kubeconfig.Clusters[0].Cluster.CertificateAuthorityData)
if err != nil {
return nil, err
}
pool := x509.NewCertPool()
for _, caCert := range caCerts {
pool.AddCert(caCert)
}
return &http.Client{
Transport: &http.Transport{
DialContext: restconfig.DialContext(k.dialer, openShiftDoc.OpenShiftCluster),
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{
{
Certificate: [][]byte{
clientCerts[0].Raw,
},
PrivateKey: clientKey,
},
},
RootCAs: pool,
},
},
}, nil
}
// roundTripper is called by ReverseProxy to make the onward request happen. We
// check if we had an error earlier and return that if we did. Otherwise we dig
// out the client and call it.
func (k *kubeconfig) roundTripper(r *http.Request) (*http.Response, error) {
if resp, ok := r.Context().Value(contextKeyResponse).(*http.Response); ok {
return resp, nil
}
cli := r.Context().Value(contextKeyClient).(*http.Client)
resp, err := cli.Do(r)
if err != nil {
return nil, err
}
if resp.StatusCode == http.StatusSwitchingProtocols {
resp.Body = newCancelBody(resp.Body.(io.ReadWriteCloser), kubeconfigTimeout)
}
return resp, err
}
func (k *kubeconfig) error(r *http.Request, statusCode int, err error) {
if err != nil {
k.log.Warn(err)
}
w := responsewriter.New(r)
http.Error(w, http.StatusText(statusCode), statusCode)
*r = *r.WithContext(context.WithValue(r.Context(), contextKeyResponse, w.Response()))
}
// cancelBody is a workaround for the fact that http timeouts are incompatible
// with hijacked connections (https://github.com/golang/go/issues/31391):
// net/http.cancelTimerBody does not implement Writer.
type cancelBody struct {
io.ReadWriteCloser
t *time.Timer
c chan struct{}
}
func (b *cancelBody) wait() {
select {
case <-b.t.C:
b.ReadWriteCloser.Close()
case <-b.c:
b.t.Stop()
}
}
func (b *cancelBody) Close() error {
select {
case b.c <- struct{}{}:
default:
}
return b.ReadWriteCloser.Close()
}
func newCancelBody(rwc io.ReadWriteCloser, d time.Duration) io.ReadWriteCloser {
b := &cancelBody{
ReadWriteCloser: rwc,
t: time.NewTimer(d),
c: make(chan struct{}),
}
go b.wait()
return b
}

Просмотреть файл

@ -0,0 +1,403 @@
package kubeconfig
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"context"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httputil"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
v1 "k8s.io/client-go/tools/clientcmd/api/v1"
"github.com/Azure/ARO-RP/pkg/api"
"github.com/Azure/ARO-RP/pkg/database/cosmosdb"
"github.com/Azure/ARO-RP/pkg/portal/util/responsewriter"
mock_proxy "github.com/Azure/ARO-RP/pkg/util/mocks/proxy"
utiltls "github.com/Azure/ARO-RP/pkg/util/tls"
testdatabase "github.com/Azure/ARO-RP/test/database"
"github.com/Azure/ARO-RP/test/util/listener"
)
// fakeServer returns a test listener for an HTTPS server which validates its
// client and echos back the request it received
func fakeServer(cacerts []*x509.Certificate, serverkey *rsa.PrivateKey, servercerts []*x509.Certificate) *listener.Listener {
l := listener.NewListener()
pool := x509.NewCertPool()
pool.AddCert(cacerts[0])
go func() {
_ = http.Serve(tls.NewListener(l, &tls.Config{
Certificates: []tls.Certificate{
{
Certificate: [][]byte{
servercerts[0].Raw,
},
PrivateKey: serverkey,
},
},
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: pool,
}), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Header.Add("X-Authenticated-Name", r.TLS.PeerCertificates[0].Subject.CommonName)
b, _ := httputil.DumpRequest(r, true)
_, _ = w.Write(b)
}))
}()
return l
}
func testKubeconfig(cacerts []*x509.Certificate, clientkey *rsa.PrivateKey, clientcerts []*x509.Certificate) ([]byte, error) {
kc := &v1.Config{
Clusters: []v1.NamedCluster{
{},
},
AuthInfos: []v1.NamedAuthInfo{
{},
},
}
var err error
kc.AuthInfos[0].AuthInfo.ClientKeyData, err = utiltls.PrivateKeyAsBytes(clientkey)
if err != nil {
return nil, err
}
kc.AuthInfos[0].AuthInfo.ClientCertificateData, err = utiltls.CertAsBytes(clientcerts[0])
if err != nil {
return nil, err
}
kc.Clusters[0].Cluster.CertificateAuthorityData, err = utiltls.CertAsBytes(cacerts[0])
if err != nil {
return nil, err
}
return json.Marshal(kc)
}
func TestProxy(t *testing.T) {
ctx := context.Background()
resourceID := "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster"
username := "username"
token := "00000000-0000-0000-0000-000000000000"
privateEndpointIP := "1.2.3.4"
cakey, cacerts, err := utiltls.GenerateKeyAndCertificate("ca", nil, nil, true, false)
if err != nil {
t.Fatal(err)
}
serverkey, servercerts, err := utiltls.GenerateKeyAndCertificate("kubernetes", cakey, cacerts[0], false, false)
if err != nil {
t.Fatal(err)
}
sreClientkey, sreClientcerts, err := utiltls.GenerateKeyAndCertificate("system:aro-sre", cakey, cacerts[0], false, true)
if err != nil {
t.Fatal(err)
}
sreKubeconfig, err := testKubeconfig(cacerts, sreClientkey, sreClientcerts)
if err != nil {
t.Fatal(err)
}
serviceClientkey, serviceClientcerts, err := utiltls.GenerateKeyAndCertificate("system:aro-service", cakey, cacerts[0], false, true)
if err != nil {
t.Fatal(err)
}
serviceKubeconfig, err := testKubeconfig(cacerts, serviceClientkey, serviceClientcerts)
if err != nil {
t.Fatal(err)
}
l := fakeServer(cacerts, serverkey, servercerts)
defer l.Close()
for _, tt := range []struct {
name string
r func(*http.Request)
fixtureChecker func(*testdatabase.Fixture, *testdatabase.Checker, *cosmosdb.FakeOpenShiftClusterDocumentClient, *cosmosdb.FakePortalDocumentClient)
mocks func(*mock_proxy.MockDialer)
wantStatusCode int
wantBody string
}{
{
name: "success - elevated",
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
portalDocument := &api.PortalDocument{
ID: token,
TTL: 21600,
Portal: &api.Portal{
Username: username,
ID: resourceID,
Kubeconfig: &api.Kubeconfig{
Elevated: true,
},
},
}
fixture.AddPortalDocuments(portalDocument)
checker.AddPortalDocuments(portalDocument)
openShiftClusterDocument := &api.OpenShiftClusterDocument{
ID: resourceID,
Key: resourceID,
OpenShiftCluster: &api.OpenShiftCluster{
Properties: api.OpenShiftClusterProperties{
NetworkProfile: api.NetworkProfile{
PrivateEndpointIP: privateEndpointIP,
},
AROServiceKubeconfig: api.SecureBytes(serviceKubeconfig),
AROSREKubeconfig: api.SecureBytes(sreKubeconfig),
},
},
}
fixture.AddOpenShiftClusterDocuments(openShiftClusterDocument)
checker.AddOpenShiftClusterDocuments(openShiftClusterDocument)
},
mocks: func(dialer *mock_proxy.MockDialer) {
dialer.EXPECT().DialContext(gomock.Any(), "tcp", privateEndpointIP+":6443").Return(l.DialContext(ctx, "", ""))
},
wantStatusCode: http.StatusOK,
wantBody: "GET /test HTTP/1.1\r\nHost: kubernetes:6443\r\nAccept-Encoding: gzip\r\nUser-Agent: Go-http-client/1.1\r\nX-Authenticated-Name: system:aro-service\r\n\r\n",
},
{
name: "success - not elevated",
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
portalDocument := &api.PortalDocument{
ID: token,
TTL: 21600,
Portal: &api.Portal{
Username: username,
ID: resourceID,
Kubeconfig: &api.Kubeconfig{},
},
}
fixture.AddPortalDocuments(portalDocument)
checker.AddPortalDocuments(portalDocument)
openShiftClusterDocument := &api.OpenShiftClusterDocument{
ID: resourceID,
Key: resourceID,
OpenShiftCluster: &api.OpenShiftCluster{
Properties: api.OpenShiftClusterProperties{
NetworkProfile: api.NetworkProfile{
PrivateEndpointIP: privateEndpointIP,
},
AROServiceKubeconfig: api.SecureBytes(serviceKubeconfig),
AROSREKubeconfig: api.SecureBytes(sreKubeconfig),
},
},
}
fixture.AddOpenShiftClusterDocuments(openShiftClusterDocument)
checker.AddOpenShiftClusterDocuments(openShiftClusterDocument)
},
mocks: func(dialer *mock_proxy.MockDialer) {
dialer.EXPECT().DialContext(gomock.Any(), "tcp", privateEndpointIP+":6443").Return(l.DialContext(ctx, "", ""))
},
wantStatusCode: http.StatusOK,
wantBody: "GET /test HTTP/1.1\r\nHost: kubernetes:6443\r\nAccept-Encoding: gzip\r\nUser-Agent: Go-http-client/1.1\r\nX-Authenticated-Name: system:aro-sre\r\n\r\n",
},
{
name: "no auth",
r: func(r *http.Request) {
r.Header.Del("Authorization")
},
wantStatusCode: http.StatusForbidden,
wantBody: "Forbidden\n",
},
{
name: "bad auth, not uuid",
r: func(r *http.Request) {
r.Header.Set("Authorization", "Bearer bad")
},
wantStatusCode: http.StatusForbidden,
wantBody: "Forbidden\n",
},
{
name: "bad auth",
wantStatusCode: http.StatusForbidden,
wantBody: "Forbidden\n",
},
{
name: "not kubeconfig record",
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
portalDocument := &api.PortalDocument{
ID: token,
TTL: 21600,
Portal: &api.Portal{
Username: username,
ID: resourceID,
},
}
fixture.AddPortalDocuments(portalDocument)
checker.AddPortalDocuments(portalDocument)
},
wantStatusCode: http.StatusForbidden,
wantBody: "Forbidden\n",
},
{
name: "bad path",
r: func(r *http.Request) {
r.URL.Path = "/subscriptions/BAD/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster/kubeconfig/proxy/test"
},
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
portalDocument := &api.PortalDocument{
ID: token,
TTL: 21600,
Portal: &api.Portal{
Username: username,
ID: resourceID,
Kubeconfig: &api.Kubeconfig{},
},
}
fixture.AddPortalDocuments(portalDocument)
checker.AddPortalDocuments(portalDocument)
},
wantStatusCode: http.StatusBadRequest,
wantBody: "Bad Request\n",
},
{
name: "mismatched path",
r: func(r *http.Request) {
r.URL.Path = "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/mismatch/kubeconfig/proxy/test"
},
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
portalDocument := &api.PortalDocument{
ID: token,
TTL: 21600,
Portal: &api.Portal{
Username: username,
ID: resourceID,
Kubeconfig: &api.Kubeconfig{},
},
}
fixture.AddPortalDocuments(portalDocument)
checker.AddPortalDocuments(portalDocument)
},
wantStatusCode: http.StatusBadRequest,
wantBody: "Bad Request\n",
},
{
name: "sad portal database",
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
portalClient.SetError(fmt.Errorf("sad"))
},
wantStatusCode: http.StatusForbidden,
wantBody: "Forbidden\n",
},
{
name: "sad openshift database",
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
portalDocument := &api.PortalDocument{
ID: token,
TTL: 21600,
Portal: &api.Portal{
Username: username,
ID: resourceID,
Kubeconfig: &api.Kubeconfig{},
},
}
fixture.AddPortalDocuments(portalDocument)
checker.AddPortalDocuments(portalDocument)
openShiftClustersClient.SetError(fmt.Errorf("sad"))
},
wantStatusCode: http.StatusInternalServerError,
wantBody: "Internal Server Error\n",
},
} {
t.Run(tt.name, func(t *testing.T) {
dbPortal, portalClient := testdatabase.NewFakePortal()
dbOpenShiftClusters, openShiftClustersClient := testdatabase.NewFakeOpenShiftClusters()
fixture := testdatabase.NewFixture().
WithOpenShiftClusters(dbOpenShiftClusters).
WithPortal(dbPortal)
checker := testdatabase.NewChecker()
if tt.fixtureChecker != nil {
tt.fixtureChecker(fixture, checker, openShiftClustersClient, portalClient)
}
err := fixture.Create()
if err != nil {
t.Fatal(err)
}
r, err := http.NewRequest(http.MethodGet,
"https://localhost:8444"+resourceID+"/kubeconfig/proxy/test", nil)
if err != nil {
panic(err)
}
r.Header.Set("Authorization", "Bearer "+token)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
dialer := mock_proxy.NewMockDialer(ctrl)
if tt.mocks != nil {
tt.mocks(dialer)
}
unauthenticatedRouter := &mux.Router{}
k := New(logrus.NewEntry(logrus.StandardLogger()), logrus.NewEntry(logrus.StandardLogger()), nil, nil, dbOpenShiftClusters, dbPortal, dialer, &mux.Router{}, unauthenticatedRouter)
k.newToken = func() string { return token }
if tt.r != nil {
tt.r(r)
}
w := responsewriter.New(r)
unauthenticatedRouter.ServeHTTP(w, r)
openShiftClustersClient.SetError(nil)
portalClient.SetError(nil)
for _, err = range checker.CheckOpenShiftClusters(openShiftClustersClient) {
t.Error(err)
}
for _, err = range checker.CheckPortals(portalClient) {
t.Error(err)
}
resp := w.Response()
if resp.StatusCode != tt.wantStatusCode {
t.Error(resp.StatusCode)
}
if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" {
t.Error(resp.Header.Get("Content-Type"))
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != tt.wantBody {
t.Errorf("%q", string(b))
}
})
}
}

Просмотреть файл

@ -0,0 +1,394 @@
package middleware
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/gob"
"errors"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/coreos/go-oidc"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
uuid "github.com/satori/go.uuid"
"github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"golang.org/x/oauth2/microsoft"
"github.com/Azure/ARO-RP/pkg/util/deployment"
"github.com/Azure/ARO-RP/pkg/util/roundtripper"
)
const (
SessionName = "session"
SessionKeyExpires = "expires"
sessionKeyRedirectPath = "redirect_path"
sessionKeyState = "state"
SessionKeyUsername = "user_name"
SessionKeyGroups = "groups"
)
func init() {
gob.Register(time.Time{})
}
// AAD is responsible for ensuring that we have a valid login session with AAD.
type AAD interface {
AAD(http.Handler) http.Handler
Logout(string) http.Handler
Redirect(http.Handler) http.Handler
}
type oauther interface {
AuthCodeURL(string, ...oauth2.AuthCodeOption) string
Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error)
}
type Verifier interface {
Verify(context.Context, string) (oidctoken, error)
}
type idTokenVerifier struct {
*oidc.IDTokenVerifier
}
func (v *idTokenVerifier) Verify(ctx context.Context, rawIDToken string) (oidctoken, error) {
return v.IDTokenVerifier.Verify(ctx, rawIDToken)
}
type oidctoken interface {
Claims(interface{}) error
}
func NewVerifier(ctx context.Context, tenantID, clientID string) (Verifier, error) {
provider, err := oidc.NewProvider(ctx, "https://login.microsoftonline.com/"+tenantID+"/v2.0")
if err != nil {
return nil, err
}
return &idTokenVerifier{
provider.Verifier(&oidc.Config{
ClientID: clientID,
}),
}, nil
}
type claims struct {
Groups []string `json:"groups,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
}
type aad struct {
deploymentMode deployment.Mode
log *logrus.Entry
now func() time.Time
rt http.RoundTripper
tenantID string
clientID string
clientKey *rsa.PrivateKey
clientCerts []*x509.Certificate
store *sessions.CookieStore
oauther oauther
verifier Verifier
groupIDs []string
sessionTimeout time.Duration
}
func NewAAD(deploymentMode deployment.Mode,
log *logrus.Entry,
baseAccessLog *logrus.Entry,
hostname string,
sessionKey []byte,
tenantID string,
clientID string,
clientKey *rsa.PrivateKey,
clientCerts []*x509.Certificate,
groupIDs []string,
unauthenticatedRouter *mux.Router,
verifier Verifier) (AAD, error) {
if len(sessionKey) != 32 {
return nil, errors.New("invalid sessionKey")
}
a := &aad{
deploymentMode: deploymentMode,
log: log,
now: time.Now,
rt: http.DefaultTransport,
tenantID: tenantID,
clientID: clientID,
clientKey: clientKey,
clientCerts: clientCerts,
store: sessions.NewCookieStore(sessionKey),
oauther: &oauth2.Config{
ClientID: clientID,
Endpoint: microsoft.AzureADEndpoint(tenantID),
RedirectURL: "https://" + hostname + "/callback",
Scopes: []string{
"openid",
"profile",
},
},
verifier: verifier,
groupIDs: groupIDs,
sessionTimeout: time.Hour,
}
a.store.MaxAge(0)
a.store.Options.Secure = true
a.store.Options.HttpOnly = true
a.store.Options.SameSite = http.SameSiteLaxMode
unauthenticatedRouter.NewRoute().Methods(http.MethodGet).Path("/callback").Handler(Log(baseAccessLog)(http.HandlerFunc(a.callback)))
return a, nil
}
// AAD is the early stage handler which adds a username to the context if it
// can. It lets the request through regardless (this is so that failures can be
// logged).
func (a *aad) AAD(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session, err := a.store.Get(r, SessionName)
if err != nil {
a.internalServerError(w, err)
return
}
expires, ok := session.Values[SessionKeyExpires].(time.Time)
if !ok || expires.Before(a.now()) {
h.ServeHTTP(w, r)
return
}
ctx := r.Context()
ctx = context.WithValue(ctx, ContextKeyUsername, session.Values[SessionKeyUsername])
ctx = context.WithValue(ctx, ContextKeyGroups, session.Values[SessionKeyGroups])
r = r.WithContext(ctx)
h.ServeHTTP(w, r)
})
}
// Redirect is the late stage (post logging) handler which redirects to AAD if
// there is no valid user.
func (a *aad) Redirect(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if ctx.Value(ContextKeyUsername) != nil {
h.ServeHTTP(w, r)
return
}
a.redirect(w, r)
})
}
func (a *aad) Logout(url string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session, err := a.store.Get(r, SessionName)
if err != nil {
a.internalServerError(w, err)
return
}
session.Values = nil
err = session.Save(r, w)
if err != nil {
a.internalServerError(w, err)
return
}
http.Redirect(w, r, url, http.StatusSeeOther)
})
}
func (a *aad) redirect(w http.ResponseWriter, r *http.Request) {
session, err := a.store.Get(r, SessionName)
if err != nil {
a.internalServerError(w, err)
return
}
path := r.URL.Path
if path == "/callback" {
path = "/"
}
state := uuid.NewV4().String()
session.Values = map[interface{}]interface{}{
sessionKeyRedirectPath: path,
sessionKeyState: state,
}
err = session.Save(r, w)
if err != nil {
a.internalServerError(w, err)
return
}
http.Redirect(w, r, a.oauther.AuthCodeURL(state), http.StatusTemporaryRedirect)
}
func (a *aad) callback(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
session, err := a.store.Get(r, SessionName)
if err != nil {
a.internalServerError(w, err)
return
}
state, ok := session.Values[sessionKeyState].(string)
if !ok {
a.redirect(w, r)
return
}
delete(session.Values, sessionKeyState)
err = session.Save(r, w)
if err != nil {
a.internalServerError(w, err)
return
}
if r.FormValue("state") != state {
a.internalServerError(w, errors.New("state mismatch"))
return
}
if r.FormValue("error") != "" {
err := r.FormValue("error")
if r.FormValue("error_description") != "" {
err += ": " + r.FormValue("error_description")
}
a.internalServerError(w, errors.New(err))
return
}
cliCtx := context.WithValue(ctx, oauth2.HTTPClient, &http.Client{
Transport: roundtripper.RoundTripperFunc(a.clientAssertion),
})
token, err := a.oauther.Exchange(cliCtx, r.FormValue("code"))
if err != nil {
a.internalServerError(w, err)
return
}
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
a.internalServerError(w, errors.New("id_token not found"))
return
}
idToken, err := a.verifier.Verify(r.Context(), rawIDToken)
if err != nil {
a.internalServerError(w, err)
return
}
var claims claims
err = idToken.Claims(&claims)
if err != nil {
a.internalServerError(w, err)
return
}
if !GroupsIntersect(a.groupIDs, claims.Groups) {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
}
redirectPath, ok := session.Values[sessionKeyRedirectPath].(string)
if !ok {
a.internalServerError(w, errors.New("redirect_path not found"))
return
}
delete(session.Values, sessionKeyRedirectPath)
session.Values[SessionKeyUsername] = claims.PreferredUsername
session.Values[SessionKeyGroups] = claims.Groups
session.Values[SessionKeyExpires] = a.now().Add(a.sessionTimeout)
err = session.Save(r, w)
if err != nil {
a.internalServerError(w, err)
return
}
http.Redirect(w, r, redirectPath, http.StatusTemporaryRedirect)
}
// clientAssertion adds a JWT client assertion according to
// https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-certificate-credentials
// Treating this as a RoundTripper is more hackery -- this is because the
// underlying oauth2 library is a little unextensible.
func (a *aad) clientAssertion(req *http.Request) (*http.Response, error) {
oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, a.tenantID)
if err != nil {
return nil, err
}
sp, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, a.clientID, a.clientCerts[0], a.clientKey, "unused")
if err != nil {
return nil, err
}
s := &adal.ServicePrincipalCertificateSecret{
Certificate: a.clientCerts[0],
PrivateKey: a.clientKey,
}
err = req.ParseForm()
if err != nil {
return nil, err
}
err = s.SetAuthenticationValues(sp, &req.Form)
if err != nil {
return nil, err
}
form := req.Form.Encode()
req.Body = ioutil.NopCloser(strings.NewReader(form))
req.ContentLength = int64(len(form))
return a.rt.RoundTrip(req)
}
func (a *aad) internalServerError(w http.ResponseWriter, err error) {
a.log.Warn(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
func GroupsIntersect(as, bs []string) bool {
for _, a := range as {
for _, b := range bs {
if a == b {
return true
}
}
}
return false
}

Просмотреть файл

@ -0,0 +1,851 @@
package middleware
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"time"
"github.com/form3tech-oss/jwt-go"
"github.com/gorilla/mux"
"github.com/gorilla/securecookie"
uuid "github.com/satori/go.uuid"
"github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/Azure/ARO-RP/pkg/util/deployment"
"github.com/Azure/ARO-RP/pkg/util/roundtripper"
utiltls "github.com/Azure/ARO-RP/pkg/util/tls"
)
var (
clientkey *rsa.PrivateKey
clientcerts []*x509.Certificate
)
func init() {
var err error
clientkey, clientcerts, err = utiltls.GenerateKeyAndCertificate("client", nil, nil, false, true)
if err != nil {
panic(err)
}
}
type noopOauther struct {
tokenMap map[string]interface{}
err error
}
func (noopOauther) AuthCodeURL(string, ...oauth2.AuthCodeOption) string {
return "authcodeurl"
}
func (o *noopOauther) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
if o.err != nil {
return nil, o.err
}
t := oauth2.Token{}
return t.WithExtra(o.tokenMap), nil
}
type noopVerifier struct {
err error
}
func (v *noopVerifier) Verify(ctx context.Context, rawtoken string) (oidctoken, error) {
if v.err != nil {
return nil, v.err
}
return noopClaims(rawtoken), nil
}
type noopClaims []byte
func (c noopClaims) Claims(v interface{}) error {
return json.Unmarshal(c, v)
}
func TestNewAAD(t *testing.T) {
_, err := NewAAD(deployment.Production, nil, nil, "", nil, "", "", nil, nil, nil, nil, nil)
if err.Error() != "invalid sessionKey" {
t.Error(err)
}
}
func TestAAD(t *testing.T) {
for _, tt := range []struct {
name string
request func(*aad) (*http.Request, error)
wantStatusCode int
wantAuthenticated bool
wantUsername string
wantGroups []string
}{
{
name: "authenticated",
request: func(a *aad) (*http.Request, error) {
cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
SessionKeyUsername: "username",
SessionKeyGroups: []string{"group1", "group2"},
SessionKeyExpires: time.Unix(1, 0),
}, a.store.Codecs...)
if err != nil {
return nil, err
}
return &http.Request{
Header: http.Header{
"Cookie": []string{SessionName + "=" + cookie},
},
}, nil
},
wantAuthenticated: true,
wantUsername: "username",
wantGroups: []string{"group1", "group2"},
},
{
name: "expired - not authenticated",
request: func(a *aad) (*http.Request, error) {
cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
SessionKeyUsername: "username",
SessionKeyGroups: []string{"group1", "group2"},
SessionKeyExpires: time.Unix(-1, 0),
}, a.store.Codecs...)
if err != nil {
return nil, err
}
return &http.Request{
Header: http.Header{
"Cookie": []string{SessionName + "=" + cookie},
},
}, nil
},
},
{
name: "no cookie - not authenticated",
request: func(a *aad) (*http.Request, error) {
return &http.Request{}, nil
},
},
{
name: "invalid cookie",
request: func(a *aad) (*http.Request, error) {
return &http.Request{
Header: http.Header{
"Cookie": []string{"session=xxx"},
},
}, nil
},
wantStatusCode: http.StatusInternalServerError,
},
} {
t.Run(tt.name, func(t *testing.T) {
a, err := NewAAD(deployment.Production, logrus.NewEntry(logrus.StandardLogger()), logrus.NewEntry(logrus.StandardLogger()), "", make([]byte, 32), "", "", nil, nil, nil, mux.NewRouter(), nil)
if err != nil {
t.Fatal(err)
}
a.(*aad).now = func() time.Time { return time.Unix(0, 0) }
var username string
var usernameok bool
var groups []string
var groupsok bool
h := a.AAD(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
username, usernameok = r.Context().Value(ContextKeyUsername).(string)
groups, groupsok = r.Context().Value(ContextKeyGroups).([]string)
}))
r, err := tt.request(a.(*aad))
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
if tt.wantStatusCode != 0 && w.Code != tt.wantStatusCode {
t.Error(w.Code)
}
if username != tt.wantUsername {
t.Error(username)
}
if usernameok != tt.wantAuthenticated {
t.Error(usernameok)
}
if !reflect.DeepEqual(groups, tt.wantGroups) {
t.Error(groups)
}
if groupsok != tt.wantAuthenticated {
t.Error(groupsok)
}
})
}
}
func TestRedirect(t *testing.T) {
for _, tt := range []struct {
name string
request func(*aad) (*http.Request, error)
wantStatusCode int
wantAuthenticated bool
}{
{
name: "authenticated",
request: func(a *aad) (*http.Request, error) {
ctx := context.Background()
ctx = context.WithValue(ctx, ContextKeyUsername, "user")
return http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
},
wantStatusCode: http.StatusOK,
wantAuthenticated: true,
},
{
name: "not authenticated",
request: func(a *aad) (*http.Request, error) {
ctx := context.Background()
return http.NewRequestWithContext(ctx, http.MethodGet, "/", nil)
},
wantStatusCode: http.StatusTemporaryRedirect,
},
{
name: "not authenticated",
request: func(a *aad) (*http.Request, error) {
ctx := context.Background()
return http.NewRequestWithContext(ctx, http.MethodGet, "/callback", nil)
},
wantStatusCode: http.StatusTemporaryRedirect,
},
{
name: "invalid cookie",
request: func(a *aad) (*http.Request, error) {
return &http.Request{
Header: http.Header{
"Cookie": []string{"session=xxx"},
},
}, nil
},
wantStatusCode: http.StatusInternalServerError,
},
} {
t.Run(tt.name, func(t *testing.T) {
a, err := NewAAD(deployment.Production, logrus.NewEntry(logrus.StandardLogger()), logrus.NewEntry(logrus.StandardLogger()), "", make([]byte, 32), "", "", nil, nil, nil, mux.NewRouter(), nil)
if err != nil {
t.Fatal(err)
}
var authenticated bool
h := a.Redirect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, authenticated = r.Context().Value(ContextKeyUsername).(string)
}))
r, err := tt.request(a.(*aad))
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
if w.Code != tt.wantStatusCode {
t.Error(w.Code)
}
if authenticated != tt.wantAuthenticated {
t.Fatal(authenticated)
}
if tt.wantStatusCode == http.StatusInternalServerError {
return
}
if !tt.wantAuthenticated {
if !strings.HasPrefix(w.Header().Get("Location"), "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?client_id=&redirect_uri=https%3A%2F%2F%2Fcallback&response_type=code&scope=openid+profile&state=") {
t.Error(w.Header().Get("Location"))
}
var m map[interface{}]interface{}
cookies := w.Result().Cookies()
err = securecookie.DecodeMulti(SessionName, cookies[len(cookies)-1].Value, &m, a.(*aad).store.Codecs...)
if err != nil {
t.Fatal(err)
}
if len(m) != 2 {
t.Error(len(m))
}
if redirectPath, ok := m[sessionKeyRedirectPath].(string); !ok ||
redirectPath != "/" {
t.Error(m[sessionKeyRedirectPath])
}
if state, ok := m[sessionKeyState].(string); !ok ||
uuid.FromStringOrNil(state) == uuid.Nil {
t.Error(m[sessionKeyState])
}
}
})
}
}
func TestLogout(t *testing.T) {
for _, tt := range []struct {
name string
request func(*aad) (*http.Request, error)
wantStatusCode int
}{
{
name: "authenticated",
request: func(a *aad) (*http.Request, error) {
cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
SessionKeyUsername: "username",
SessionKeyGroups: []string{"group1", "group2"},
SessionKeyExpires: time.Unix(1, 0),
}, a.store.Codecs...)
if err != nil {
return nil, err
}
return &http.Request{
URL: &url.URL{},
Header: http.Header{
"Cookie": []string{SessionName + "=" + cookie},
},
}, nil
},
wantStatusCode: http.StatusSeeOther,
},
{
name: "no cookie - not authenticated",
request: func(a *aad) (*http.Request, error) {
return &http.Request{
URL: &url.URL{},
}, nil
},
wantStatusCode: http.StatusSeeOther,
},
{
name: "invalid cookie",
request: func(a *aad) (*http.Request, error) {
return &http.Request{
Header: http.Header{
"Cookie": []string{"session=xxx"},
},
}, nil
},
wantStatusCode: http.StatusInternalServerError,
},
} {
t.Run(tt.name, func(t *testing.T) {
a, err := NewAAD(deployment.Production, logrus.NewEntry(logrus.StandardLogger()), logrus.NewEntry(logrus.StandardLogger()), "", make([]byte, 32), "", "", nil, nil, nil, mux.NewRouter(), nil)
if err != nil {
t.Fatal(err)
}
h := a.Logout("/bye")
r, err := tt.request(a.(*aad))
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
if w.Code != tt.wantStatusCode {
t.Error(w.Code)
}
if tt.wantStatusCode == http.StatusInternalServerError {
return
}
if w.Header().Get("Location") != "/bye" {
t.Error(w.Header().Get("Location"))
}
var m map[interface{}]interface{}
cookies := w.Result().Cookies()
err = securecookie.DecodeMulti(SessionName, cookies[len(cookies)-1].Value, &m, a.(*aad).store.Codecs...)
if err != nil {
t.Fatal(err)
}
if len(m) != 0 {
t.Error(len(m))
}
})
}
}
func TestCallback(t *testing.T) {
clientID := "00000000-0000-0000-0000-000000000000"
groups := []string{
"00000000-0000-0000-0000-000000000001",
}
username := "user"
idToken, err := json.Marshal(claims{
Groups: groups,
PreferredUsername: username,
})
if err != nil {
t.Fatal(err)
}
for _, tt := range []struct {
name string
request func(*aad) (*http.Request, error)
oauther oauther
verifier Verifier
wantAuthenticated bool
wantError string
wantForbidden bool
}{
{
name: "success",
request: func(a *aad) (*http.Request, error) {
uuid := uuid.NewV4().String()
cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
sessionKeyState: uuid,
sessionKeyRedirectPath: "/",
}, a.store.Codecs...)
if err != nil {
return nil, err
}
return &http.Request{
URL: &url.URL{},
Header: http.Header{
"Cookie": []string{SessionName + "=" + cookie},
},
Form: url.Values{
"state": []string{uuid},
},
}, nil
},
oauther: &noopOauther{
tokenMap: map[string]interface{}{
"id_token": string(idToken),
},
},
verifier: &noopVerifier{},
wantAuthenticated: true,
},
{
name: "fail - invalid cookie",
request: func(a *aad) (*http.Request, error) {
return &http.Request{
Header: http.Header{
"Cookie": []string{"session=xxx"},
},
}, nil
},
wantError: "Internal Server Error\n",
},
{
name: "fail - corrupt sessionKeyState",
request: func(a *aad) (*http.Request, error) {
return &http.Request{
URL: &url.URL{},
}, nil
},
oauther: &noopOauther{},
},
{
name: "fail - state mismatch",
request: func(a *aad) (*http.Request, error) {
uuid := uuid.NewV4().String()
cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
sessionKeyState: uuid,
sessionKeyRedirectPath: "/",
}, a.store.Codecs...)
if err != nil {
return nil, err
}
return &http.Request{
URL: &url.URL{},
Header: http.Header{
"Cookie": []string{SessionName + "=" + cookie},
},
Form: url.Values{
"state": []string{"bad"},
},
}, nil
},
wantError: "Internal Server Error\n",
},
{
name: "fail - error returned",
request: func(a *aad) (*http.Request, error) {
uuid := uuid.NewV4().String()
cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
sessionKeyState: uuid,
sessionKeyRedirectPath: "/",
}, a.store.Codecs...)
if err != nil {
return nil, err
}
return &http.Request{
URL: &url.URL{},
Header: http.Header{
"Cookie": []string{SessionName + "=" + cookie},
},
Form: url.Values{
"state": []string{uuid},
"error": []string{"bad things happened."},
"error_description": []string{"really bad things."},
},
}, nil
},
wantError: "Internal Server Error\n",
},
{
name: "fail - oauther failed",
request: func(a *aad) (*http.Request, error) {
uuid := uuid.NewV4().String()
cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
sessionKeyState: uuid,
sessionKeyRedirectPath: "/",
}, a.store.Codecs...)
if err != nil {
return nil, err
}
return &http.Request{
URL: &url.URL{},
Header: http.Header{
"Cookie": []string{SessionName + "=" + cookie},
},
Form: url.Values{
"state": []string{uuid},
},
}, nil
},
oauther: &noopOauther{
err: fmt.Errorf("failed"),
},
wantError: "Internal Server Error\n",
},
{
name: "fail - no idtoken",
request: func(a *aad) (*http.Request, error) {
uuid := uuid.NewV4().String()
cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
sessionKeyState: uuid,
sessionKeyRedirectPath: "/",
}, a.store.Codecs...)
if err != nil {
return nil, err
}
return &http.Request{
URL: &url.URL{},
Header: http.Header{
"Cookie": []string{SessionName + "=" + cookie},
},
Form: url.Values{
"state": []string{uuid},
},
}, nil
},
oauther: &noopOauther{},
wantError: "Internal Server Error\n",
},
{
name: "fail - verifier error",
request: func(a *aad) (*http.Request, error) {
uuid := uuid.NewV4().String()
cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
sessionKeyState: uuid,
sessionKeyRedirectPath: "/",
}, a.store.Codecs...)
if err != nil {
return nil, err
}
return &http.Request{
URL: &url.URL{},
Header: http.Header{
"Cookie": []string{SessionName + "=" + cookie},
},
Form: url.Values{
"state": []string{uuid},
},
}, nil
},
oauther: &noopOauther{
tokenMap: map[string]interface{}{"id_token": ""},
},
verifier: &noopVerifier{
err: fmt.Errorf("failed"),
},
wantError: "Internal Server Error\n",
},
{
name: "fail - invalid claims",
request: func(a *aad) (*http.Request, error) {
uuid := uuid.NewV4().String()
cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
sessionKeyState: uuid,
sessionKeyRedirectPath: "/",
}, a.store.Codecs...)
if err != nil {
return nil, err
}
return &http.Request{
URL: &url.URL{},
Header: http.Header{
"Cookie": []string{SessionName + "=" + cookie},
},
Form: url.Values{
"state": []string{uuid},
},
}, nil
},
oauther: &noopOauther{
tokenMap: map[string]interface{}{
"id_token": "",
},
},
verifier: &noopVerifier{},
wantError: "Internal Server Error\n",
},
{
name: "fail - group mismatch",
request: func(a *aad) (*http.Request, error) {
uuid := uuid.NewV4().String()
cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
sessionKeyState: uuid,
sessionKeyRedirectPath: "/",
}, a.store.Codecs...)
if err != nil {
return nil, err
}
return &http.Request{
URL: &url.URL{},
Header: http.Header{
"Cookie": []string{SessionName + "=" + cookie},
},
Form: url.Values{
"state": []string{uuid},
},
}, nil
},
oauther: &noopOauther{
tokenMap: map[string]interface{}{
"id_token": "null",
},
},
verifier: &noopVerifier{},
wantForbidden: true,
},
{
name: "fail - missing redirect_path",
request: func(a *aad) (*http.Request, error) {
uuid := uuid.NewV4().String()
cookie, err := securecookie.EncodeMulti(SessionName, map[interface{}]interface{}{
sessionKeyState: uuid,
}, a.store.Codecs...)
if err != nil {
return nil, err
}
return &http.Request{
URL: &url.URL{},
Header: http.Header{
"Cookie": []string{SessionName + "=" + cookie},
},
Form: url.Values{
"state": []string{uuid},
},
}, nil
},
oauther: &noopOauther{
tokenMap: map[string]interface{}{
"id_token": string(idToken),
},
},
verifier: &noopVerifier{},
wantError: "Internal Server Error\n",
},
} {
t.Run(tt.name, func(t *testing.T) {
a, err := NewAAD(deployment.Production, logrus.NewEntry(logrus.StandardLogger()), logrus.NewEntry(logrus.StandardLogger()), "", make([]byte, 32), "", clientID, clientkey, clientcerts, groups, mux.NewRouter(), tt.verifier)
if err != nil {
t.Fatal(err)
}
a.(*aad).now = func() time.Time { return time.Unix(0, 0) }
a.(*aad).oauther = tt.oauther
r, err := tt.request(a.(*aad))
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
a.(*aad).callback(w, r)
if tt.wantError != "" {
if w.Code != http.StatusInternalServerError {
t.Error(w.Code)
}
if w.Body.String() != tt.wantError {
t.Error(w.Body.String())
}
return
}
var m map[interface{}]interface{}
cookies := w.Result().Cookies()
err = securecookie.DecodeMulti(SessionName, cookies[len(cookies)-1].Value, &m, a.(*aad).store.Codecs...)
if err != nil {
t.Fatal(err)
}
switch {
case tt.wantAuthenticated:
if w.Code != http.StatusTemporaryRedirect {
t.Error(w.Code)
}
if w.Header().Get("Location") != "/" {
t.Error(w.Header().Get("Location"))
}
if len(m) != 3 {
t.Error(len(m))
}
if expires, ok := m[SessionKeyExpires].(time.Time); !ok ||
expires != time.Unix(3600, 0) {
t.Error(m[SessionKeyExpires])
}
if sessionGroups, ok := m[SessionKeyGroups].([]string); !ok ||
!reflect.DeepEqual(sessionGroups, groups) {
t.Error(m[SessionKeyGroups])
}
if sessionUsername, ok := m[SessionKeyUsername].(string); !ok ||
sessionUsername != username {
t.Error(m[SessionKeyUsername])
}
case tt.wantForbidden:
if w.Code != http.StatusForbidden {
t.Error(w.Code)
}
if w.Header().Get("Location") != "/" {
t.Error(w.Header().Get("Location"))
}
if len(m) != 1 {
t.Error(len(m))
}
if redirectPath, ok := m[sessionKeyRedirectPath].(string); !ok ||
redirectPath != "/" {
t.Error(m[sessionKeyRedirectPath])
}
default:
if w.Code != http.StatusTemporaryRedirect {
t.Error(w.Code)
}
if w.Header().Get("Location") != "/authcodeurl" {
t.Error(w.Header().Get("Location"))
}
if len(m) != 2 {
t.Error(len(m))
}
if redirectPath, ok := m[sessionKeyRedirectPath].(string); !ok ||
redirectPath != "" {
t.Error(m[sessionKeyRedirectPath])
}
if state, ok := m[sessionKeyState].(string); !ok ||
uuid.FromStringOrNil(state) == uuid.Nil {
t.Error(m[sessionKeyState])
}
}
})
}
}
func TestClientAssertion(t *testing.T) {
clientID := "00000000-0000-0000-0000-000000000000"
a, err := NewAAD(deployment.Production, logrus.NewEntry(logrus.StandardLogger()), logrus.NewEntry(logrus.StandardLogger()), "", make([]byte, 32), "", clientID, clientkey, clientcerts, nil, mux.NewRouter(), nil)
if err != nil {
t.Fatal(err)
}
a.(*aad).rt = roundtripper.RoundTripperFunc(func(*http.Request) (*http.Response, error) {
return nil, nil
})
req, err := http.NewRequest(http.MethodGet, "https://localhost/", nil)
if err != nil {
t.Fatal(err)
}
req.Form = url.Values{"test": []string{"value"}}
_, err = a.(*aad).clientAssertion(req)
if err != nil {
t.Fatal(err)
}
if req.Form.Get("test") != "value" {
t.Error(req.Form.Get("test"))
}
if req.Form.Get("client_assertion_type") != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
t.Error(req.Form.Get("client_assertion_type"))
}
p := &jwt.Parser{}
_, err = p.Parse(req.Form.Get("client_assertion"), func(*jwt.Token) (interface{}, error) {
return &clientkey.PublicKey, nil
})
if err != nil {
t.Fatal(err)
}
}

Просмотреть файл

@ -0,0 +1,50 @@
package middleware
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"context"
"net/http"
"strings"
uuid "github.com/satori/go.uuid"
"github.com/Azure/ARO-RP/pkg/database"
)
// Bearer validates a Bearer token and adds the corresponding username to the
// context if it checks out. It lets the request through regardless (this is so
// that failures can be logged).
func Bearer(dbPortal database.Portal) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
authorization := r.Header.Get("Authorization")
if !strings.HasPrefix(authorization, "Bearer ") {
h.ServeHTTP(w, r)
return
}
token, err := uuid.FromString(strings.TrimPrefix(authorization, "Bearer "))
if err != nil {
h.ServeHTTP(w, r)
return
}
portalDoc, err := dbPortal.Get(ctx, token.String())
if err != nil {
h.ServeHTTP(w, r)
return
}
ctx = context.WithValue(ctx, ContextKeyUsername, portalDoc.Portal.Username)
ctx = context.WithValue(ctx, ContextKeyPortalDoc, portalDoc)
r = r.WithContext(ctx)
h.ServeHTTP(w, r)
})
}
}

Просмотреть файл

@ -0,0 +1,123 @@
package middleware
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Azure/ARO-RP/pkg/api"
testdatabase "github.com/Azure/ARO-RP/test/database"
)
func TestBearer(t *testing.T) {
for _, tt := range []struct {
name string
fixture func(*testdatabase.Fixture)
request func() (*http.Request, error)
wantAuthenticated bool
wantUsername string
}{
{
name: "authenticated",
fixture: func(fixture *testdatabase.Fixture) {
fixture.AddPortalDocuments(&api.PortalDocument{
ID: "00000000-0000-0000-0000-000000000000",
Portal: &api.Portal{
Username: "username",
},
})
},
request: func() (*http.Request, error) {
return &http.Request{
Header: http.Header{
"Authorization": []string{"Bearer 00000000-0000-0000-0000-000000000000"},
},
}, nil
},
wantAuthenticated: true,
wantUsername: "username",
},
{
name: "not authenticated - no header",
request: func() (*http.Request, error) {
return &http.Request{}, nil
},
},
{
name: "not authenticated - bad header",
request: func() (*http.Request, error) {
return &http.Request{
Header: http.Header{
"Authorization": []string{"Bearer bad"},
},
}, nil
},
},
{
name: "not authenticated - berarer not found",
fixture: func(fixture *testdatabase.Fixture) {
fixture.AddPortalDocuments(&api.PortalDocument{
ID: "00000000-0000-0000-0000-000000000000",
Portal: &api.Portal{
Username: "username",
},
})
},
request: func() (*http.Request, error) {
return &http.Request{
Header: http.Header{
"Authorization": []string{"Bearer 10000000-0000-0000-0000-000000000000"},
},
}, nil
},
},
} {
dbPortal, _ := testdatabase.NewFakePortal()
fixture := testdatabase.NewFixture().
WithPortal(dbPortal)
if tt.fixture != nil {
tt.fixture(fixture)
}
err := fixture.Create()
if err != nil {
t.Fatal(err)
}
var username string
var usernameok bool
var portaldoc *api.PortalDocument
var portaldocok bool
h := Bearer(dbPortal)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
username, usernameok = r.Context().Value(ContextKeyUsername).(string)
portaldoc, portaldocok = r.Context().Value(ContextKeyPortalDoc).(*api.PortalDocument)
}))
r, err := tt.request()
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
if username != tt.wantUsername {
t.Error(username)
}
if usernameok != tt.wantAuthenticated {
t.Error(usernameok)
}
if portaldocok != tt.wantAuthenticated {
t.Error(portaldocok)
}
if tt.wantAuthenticated && portaldoc.Portal.Username != username {
t.Error(portaldoc.Portal.Username)
}
}
}

Просмотреть файл

@ -0,0 +1,88 @@
package middleware
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"bufio"
"io"
"net"
"net/http"
"time"
"github.com/sirupsen/logrus"
utillog "github.com/Azure/ARO-RP/pkg/util/log"
)
type logResponseWriter struct {
http.ResponseWriter
statusCode int
bytes int
}
func (w *logResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker := w.ResponseWriter.(http.Hijacker)
return hijacker.Hijack()
}
func (w *logResponseWriter) Write(b []byte) (int, error) {
n, err := w.ResponseWriter.Write(b)
w.bytes += n
return n, err
}
func (w *logResponseWriter) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
w.statusCode = statusCode
}
type logReadCloser struct {
io.ReadCloser
bytes int
}
func (rc *logReadCloser) Read(b []byte) (int, error) {
n, err := rc.ReadCloser.Read(b)
rc.bytes += n
return n, err
}
func Log(baseLog *logrus.Entry) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t := time.Now()
r.Body = &logReadCloser{ReadCloser: r.Body}
w = &logResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
log := baseLog
log = utillog.EnrichWithPath(log, r.URL.Path)
username, _ := r.Context().Value(ContextKeyUsername).(string)
log = log.WithFields(logrus.Fields{
"request_method": r.Method,
"request_path": r.URL.Path,
"request_proto": r.Proto,
"request_remote_addr": r.RemoteAddr,
"request_user_agent": r.UserAgent(),
"username": username,
})
log.Print("read request")
defer func() {
log.WithFields(logrus.Fields{
"body_read_bytes": r.Body.(*logReadCloser).bytes,
"body_written_bytes": w.(*logResponseWriter).bytes,
"duration": time.Since(t).Seconds(),
"response_status_code": w.(*logResponseWriter).statusCode,
}).Print("sent response")
}()
h.ServeHTTP(w, r)
})
}
}

Просмотреть файл

@ -0,0 +1,82 @@
package middleware
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/onsi/gomega"
"github.com/onsi/gomega/types"
"github.com/sirupsen/logrus"
testlog "github.com/Azure/ARO-RP/test/util/log"
)
func TestLog(t *testing.T) {
h, log := testlog.New()
ctx := context.WithValue(context.Background(), ContextKeyUsername, "username")
r, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://localhost/", strings.NewReader("body"))
if err != nil {
t.Fatal(err)
}
r.RemoteAddr = "127.0.0.1:1234"
r.Header.Set("User-Agent", "user-agent")
w := httptest.NewRecorder()
Log(log)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.URL = nil // mutate the request
_ = w.(http.Hijacker) // must implement http.Hijacker
w.WriteHeader(http.StatusOK)
_, _ = io.Copy(w, r.Body)
})).ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Error(w.Code)
}
expected := []map[string]types.GomegaMatcher{
{
"msg": gomega.Equal("read request"),
"level": gomega.Equal(logrus.InfoLevel),
"request_method": gomega.Equal("POST"),
"request_path": gomega.Equal("/"),
"request_proto": gomega.Equal("HTTP/1.1"),
"request_remote_addr": gomega.Equal("127.0.0.1:1234"),
"request_user_agent": gomega.Equal("user-agent"),
"username": gomega.Equal("username"),
},
{
"msg": gomega.Equal("sent response"),
"level": gomega.Equal(logrus.InfoLevel),
"body_read_bytes": gomega.Equal(4),
"body_written_bytes": gomega.Equal(4),
"response_status_code": gomega.Equal(http.StatusOK),
"request_method": gomega.Equal("POST"),
"request_path": gomega.Equal("/"),
"request_proto": gomega.Equal("HTTP/1.1"),
"request_remote_addr": gomega.Equal("127.0.0.1:1234"),
"request_user_agent": gomega.Equal("user-agent"),
"username": gomega.Equal("username"),
},
}
err = testlog.AssertLoggingOutput(h, expected)
if err != nil {
t.Error(err)
}
for _, e := range h.Entries {
fmt.Println(e)
}
}

Просмотреть файл

@ -0,0 +1,12 @@
package middleware
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
type contextKey int
const (
ContextKeyUsername contextKey = iota
ContextKeyGroups
ContextKeyPortalDoc
)

Просмотреть файл

@ -0,0 +1,26 @@
package middleware
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"net/http"
"runtime/debug"
"github.com/sirupsen/logrus"
)
func Panic(log *logrus.Entry) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if e := recover(); e != nil {
log.Errorf("panic: %#v\n%s\n", e, string(debug.Stack()))
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}()
h.ServeHTTP(w, r)
})
}
}

Просмотреть файл

@ -0,0 +1,65 @@
package middleware
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"net/http"
"net/http/httptest"
"regexp"
"testing"
"github.com/onsi/gomega"
"github.com/onsi/gomega/types"
"github.com/sirupsen/logrus"
testlog "github.com/Azure/ARO-RP/test/util/log"
)
func TestPanic(t *testing.T) {
for _, tt := range []struct {
name string
panictext string
}{
{
name: "ok",
},
{
name: "panic",
panictext: "random error",
},
} {
h, log := testlog.New()
w := httptest.NewRecorder()
Panic(log)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if tt.panictext != "" {
panic(tt.panictext)
}
})).ServeHTTP(w, nil)
var expected []map[string]types.GomegaMatcher
if tt.panictext == "" {
if w.Code != http.StatusOK {
t.Error(w.Code)
}
} else {
if w.Code != http.StatusInternalServerError {
t.Error(w.Code)
}
expected = []map[string]types.GomegaMatcher{
{
"msg": gomega.MatchRegexp(regexp.QuoteMeta(tt.panictext)),
"level": gomega.Equal(logrus.ErrorLevel),
},
}
}
err := testlog.AssertLoggingOutput(h, expected)
if err != nil {
t.Error(err)
}
}
}

286
pkg/portal/portal.go Normal file
Просмотреть файл

@ -0,0 +1,286 @@
package portal
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"bytes"
"context"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/json"
"html/template"
"log"
"net"
"net/http"
"sort"
"time"
"github.com/gorilla/csrf"
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
"github.com/Azure/ARO-RP/pkg/api"
"github.com/Azure/ARO-RP/pkg/database"
"github.com/Azure/ARO-RP/pkg/env"
frontendmiddleware "github.com/Azure/ARO-RP/pkg/frontend/middleware"
"github.com/Azure/ARO-RP/pkg/portal/kubeconfig"
"github.com/Azure/ARO-RP/pkg/portal/middleware"
"github.com/Azure/ARO-RP/pkg/portal/prometheus"
"github.com/Azure/ARO-RP/pkg/portal/ssh"
"github.com/Azure/ARO-RP/pkg/proxy"
)
type Runnable interface {
Run(context.Context) error
}
type portal struct {
env env.Core
log *logrus.Entry
baseAccessLog *logrus.Entry
l net.Listener
sshl net.Listener
verifier middleware.Verifier
hostname string
servingKey *rsa.PrivateKey
servingCerts []*x509.Certificate
clientID string
clientKey *rsa.PrivateKey
clientCerts []*x509.Certificate
sessionKey []byte
sshKey *rsa.PrivateKey
groupIDs []string
elevatedGroupIDs []string
dbPortal database.Portal
dbOpenShiftClusters database.OpenShiftClusters
dialer proxy.Dialer
t *template.Template
aad middleware.AAD
}
func NewPortal(env env.Core,
log *logrus.Entry,
baseAccessLog *logrus.Entry,
l net.Listener,
sshl net.Listener,
verifier middleware.Verifier,
hostname string,
servingKey *rsa.PrivateKey,
servingCerts []*x509.Certificate,
clientID string,
clientKey *rsa.PrivateKey,
clientCerts []*x509.Certificate,
sessionKey []byte,
sshKey *rsa.PrivateKey,
groupIDs []string,
elevatedGroupIDs []string,
dbOpenShiftClusters database.OpenShiftClusters,
dbPortal database.Portal,
dialer proxy.Dialer) Runnable {
return &portal{
env: env,
log: log,
baseAccessLog: baseAccessLog,
l: l,
sshl: sshl,
verifier: verifier,
hostname: hostname,
servingKey: servingKey,
servingCerts: servingCerts,
clientID: clientID,
clientKey: clientKey,
clientCerts: clientCerts,
sessionKey: sessionKey,
sshKey: sshKey,
groupIDs: groupIDs,
elevatedGroupIDs: elevatedGroupIDs,
dbOpenShiftClusters: dbOpenShiftClusters,
dbPortal: dbPortal,
dialer: dialer,
}
}
func (p *portal) Run(ctx context.Context) error {
asset, err := Asset("index.html")
if err != nil {
return err
}
p.t, err = template.New("index.html").Parse(string(asset))
if err != nil {
return err
}
config := &tls.Config{
Certificates: []tls.Certificate{
{
PrivateKey: p.servingKey,
},
},
NextProtos: []string{"h2", "http/1.1"},
CipherSuites: []uint16{
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
},
PreferServerCipherSuites: true,
SessionTicketsDisabled: true,
MinVersion: tls.VersionTLS12,
CurvePreferences: []tls.CurveID{
tls.CurveP256,
tls.X25519,
},
}
for _, cert := range p.servingCerts {
config.Certificates[0].Certificate = append(config.Certificates[0].Certificate, cert.Raw)
}
r := mux.NewRouter()
r.Use(middleware.Panic(p.log))
unauthenticatedRouter := r.NewRoute().Subrouter()
p.unauthenticatedRoutes(unauthenticatedRouter)
p.aad, err = middleware.NewAAD(p.env.DeploymentMode(), p.log, p.baseAccessLog, p.hostname, p.sessionKey, p.env.TenantID(), p.clientID, p.clientKey, p.clientCerts, p.groupIDs, unauthenticatedRouter, p.verifier)
if err != nil {
return err
}
aadAuthenticatedRouter := r.NewRoute().Subrouter()
aadAuthenticatedRouter.Use(p.aad.AAD)
aadAuthenticatedRouter.Use(middleware.Log(p.baseAccessLog))
aadAuthenticatedRouter.Use(p.aad.Redirect)
aadAuthenticatedRouter.Use(csrf.Protect(p.sessionKey, csrf.SameSite(csrf.SameSiteStrictMode), csrf.MaxAge(0)))
p.aadAuthenticatedRoutes(aadAuthenticatedRouter)
ssh, err := ssh.New(p.env, p.log, p.baseAccessLog, p.sshl, p.sshKey, p.elevatedGroupIDs, p.dbOpenShiftClusters, p.dbPortal, p.dialer, aadAuthenticatedRouter)
if err != nil {
return err
}
err = ssh.Run()
if err != nil {
return err
}
kubeconfig.New(p.log, p.baseAccessLog, p.servingCerts[0], p.elevatedGroupIDs, p.dbOpenShiftClusters, p.dbPortal, p.dialer, aadAuthenticatedRouter, unauthenticatedRouter)
prometheus.New(p.log, p.dbOpenShiftClusters, p.dialer, aadAuthenticatedRouter)
s := &http.Server{
Handler: frontendmiddleware.Lowercase(r),
ReadTimeout: 10 * time.Second,
IdleTimeout: 2 * time.Minute,
ErrorLog: log.New(p.log.Writer(), "", 0),
BaseContext: func(net.Listener) context.Context { return ctx },
}
return s.Serve(tls.NewListener(p.l, config))
}
func (p *portal) unauthenticatedRoutes(r *mux.Router) {
logger := middleware.Log(p.baseAccessLog)
r.NewRoute().Methods(http.MethodGet).Path("/healthz/ready").Handler(logger(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})))
}
func (p *portal) aadAuthenticatedRoutes(r *mux.Router) {
for _, name := range AssetNames() {
if name == "index.html" {
continue
}
r.NewRoute().Methods(http.MethodGet).Path("/" + name).HandlerFunc(p.serve(name))
}
r.NewRoute().Methods(http.MethodGet).Path("/").HandlerFunc(p.index)
r.NewRoute().Methods(http.MethodGet).Path("/api/clusters").HandlerFunc(p.clusters)
r.NewRoute().Methods(http.MethodPost).Path("/api/logout").Handler(p.aad.Logout("/"))
}
func (p *portal) serve(path string) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
b, err := Asset(path)
if err != nil {
p.internalServerError(w, err)
return
}
http.ServeContent(w, r, path, time.Time{}, bytes.NewReader(b))
}
}
func (p *portal) index(w http.ResponseWriter, r *http.Request) {
buf := &bytes.Buffer{}
err := p.t.ExecuteTemplate(buf, "index.html", map[string]interface{}{
"location": p.env.Location(),
csrf.TemplateTag: csrf.TemplateField(r),
})
if err != nil {
p.internalServerError(w, err)
return
}
http.ServeContent(w, r, "index.html", time.Time{}, bytes.NewReader(buf.Bytes()))
}
func (p *portal) clusters(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
docs, err := p.dbOpenShiftClusters.ListAll(ctx)
if err != nil {
p.internalServerError(w, err)
return
}
clusters := make([]string, 0, len(docs.OpenShiftClusterDocuments))
for _, doc := range docs.OpenShiftClusterDocuments {
ps := doc.OpenShiftCluster.Properties.ProvisioningState
fps := doc.OpenShiftCluster.Properties.FailedProvisioningState
switch {
case ps == api.ProvisioningStateCreating,
ps == api.ProvisioningStateDeleting,
ps == api.ProvisioningStateFailed &&
(fps == api.ProvisioningStateCreating ||
fps == api.ProvisioningStateDeleting):
default:
clusters = append(clusters, doc.OpenShiftCluster.ID)
}
}
sort.Strings(clusters)
b, err := json.MarshalIndent(clusters, "", " ")
if err != nil {
p.internalServerError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(b)
}
func (p *portal) internalServerError(w http.ResponseWriter, err error) {
p.log.Warn(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}

67
pkg/portal/portal_test.go Normal file
Просмотреть файл

@ -0,0 +1,67 @@
package portal
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"encoding/json"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/Azure/ARO-RP/pkg/api"
testdatabase "github.com/Azure/ARO-RP/test/database"
)
func TestClusters(t *testing.T) {
dbOpenShiftClusters, _ := testdatabase.NewFakeOpenShiftClusters()
fixture := testdatabase.NewFixture().
WithOpenShiftClusters(dbOpenShiftClusters)
fixture.AddOpenShiftClusterDocuments(&api.OpenShiftClusterDocument{
Key: "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourcegroupname/providers/microsoft.redhatopenshift/openshiftclusters/succeeded",
OpenShiftCluster: &api.OpenShiftCluster{
ID: "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourceGroupName/providers/microsoft.redhatopenshift/openshiftclusters/succeeded",
Properties: api.OpenShiftClusterProperties{
ProvisioningState: api.ProvisioningStateSucceeded,
},
},
}, &api.OpenShiftClusterDocument{
Key: "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourcegroupname/providers/microsoft.redhatopenshift/openshiftclusters/creating",
OpenShiftCluster: &api.OpenShiftCluster{
ID: "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourceGroupName/providers/microsoft.redhatopenshift/openshiftclusters/creating",
Properties: api.OpenShiftClusterProperties{
ProvisioningState: api.ProvisioningStateCreating,
},
},
})
err := fixture.Create()
if err != nil {
t.Fatal(err)
}
p := &portal{
dbOpenShiftClusters: dbOpenShiftClusters,
}
w := httptest.NewRecorder()
p.clusters(w, &http.Request{})
if w.Header().Get("Content-Type") != "application/json" {
t.Error(w.Header().Get("Content-Type"))
}
var r []string
err = json.NewDecoder(w.Body).Decode(&r)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(r, []string{"/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourceGroupName/providers/microsoft.redhatopenshift/openshiftclusters/succeeded"}) {
t.Error(r)
}
}

Просмотреть файл

@ -0,0 +1,57 @@
package prometheus
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"log"
"net/http"
"net/http/httputil"
"time"
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
"github.com/Azure/ARO-RP/pkg/database"
"github.com/Azure/ARO-RP/pkg/portal/util/clientcache"
"github.com/Azure/ARO-RP/pkg/proxy"
"github.com/Azure/ARO-RP/pkg/util/roundtripper"
)
type prometheus struct {
log *logrus.Entry
dbOpenShiftClusters database.OpenShiftClusters
dialer proxy.Dialer
clientCache clientcache.ClientCache
}
func New(baseLog *logrus.Entry,
dbOpenShiftClusters database.OpenShiftClusters,
dialer proxy.Dialer,
aadAuthenticatedRouter *mux.Router) *prometheus {
p := &prometheus{
log: baseLog,
dbOpenShiftClusters: dbOpenShiftClusters,
dialer: dialer,
clientCache: clientcache.New(time.Hour),
}
rp := &httputil.ReverseProxy{
Director: p.director,
Transport: roundtripper.RoundTripperFunc(p.roundTripper),
ModifyResponse: p.modifyResponse,
ErrorLog: log.New(p.log.Writer(), "", 0),
}
aadAuthenticatedRouter.NewRoute().Path("/subscriptions/{subscriptionId}/resourcegroups/{resourceGroupName}/providers/microsoft.redhatopenshift/openshiftclusters/{resourceName}/prometheus").HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.URL.Path += "/"
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
})
aadAuthenticatedRouter.NewRoute().PathPrefix("/subscriptions/{subscriptionId}/resourcegroups/{resourceGroupName}/providers/microsoft.redhatopenshift/openshiftclusters/{resourceName}/prometheus/").Handler(rp)
return p
}

Просмотреть файл

@ -0,0 +1,171 @@
package prometheus
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"bytes"
"context"
"io/ioutil"
"mime"
"net"
"net/http"
"strconv"
"strings"
"golang.org/x/net/html"
"golang.org/x/net/html/atom"
"github.com/Azure/ARO-RP/pkg/api/validate"
"github.com/Azure/ARO-RP/pkg/portal/util/responsewriter"
"github.com/Azure/ARO-RP/pkg/util/portforward"
"github.com/Azure/ARO-RP/pkg/util/restconfig"
)
// Unfortunately the signature of httputil.ReverseProxy.Director does not allow
// us to return errors. We get around this limitation slightly naughtily by
// storing return information in the request context.
type contextKey int
const (
contextKeyClient contextKey = iota
contextKeyResponse
)
func (p *prometheus) director(r *http.Request) {
ctx := r.Context()
resourceID := strings.Join(strings.Split(r.URL.Path, "/")[:9], "/")
if !validate.RxClusterID.MatchString(resourceID) {
p.error(r, http.StatusBadRequest, nil)
return
}
cli := p.clientCache.Get(resourceID)
if cli == nil {
var err error
cli, err = p.cli(ctx, resourceID)
if err != nil {
p.error(r, http.StatusInternalServerError, err)
return
}
p.clientCache.Put(resourceID, cli)
}
r.RequestURI = ""
r.URL.Scheme = "http"
r.URL.Host = "prometheus-k8s-0:9090"
r.URL.Path = "/" + strings.Join(strings.Split(r.URL.Path, "/")[10:], "/")
r.Header.Del("Cookie")
r.Header.Del("Referer")
r.Host = r.URL.Host
// http.Request.WithContext returns a copy of the original Request with the
// new context, but we have no way to return it, so we overwrite our
// existing request.
*r = *r.WithContext(context.WithValue(ctx, contextKeyClient, cli))
}
func (p *prometheus) cli(ctx context.Context, resourceID string) (*http.Client, error) {
openShiftDoc, err := p.dbOpenShiftClusters.Get(ctx, resourceID)
if err != nil {
return nil, err
}
restconfig, err := restconfig.RestConfig(p.dialer, openShiftDoc.OpenShiftCluster)
if err != nil {
return nil, err
}
return &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return portforward.DialContext(ctx, p.log, restconfig, "openshift-monitoring", "prometheus-k8s-0", "9090")
},
},
}, nil
}
func (p *prometheus) roundTripper(r *http.Request) (*http.Response, error) {
if resp, ok := r.Context().Value(contextKeyResponse).(*http.Response); ok {
return resp, nil
}
cli := r.Context().Value(contextKeyClient).(*http.Client)
return cli.Do(r)
}
func (p *prometheus) modifyResponse(r *http.Response) error {
mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
if mediaType != "text/html" {
return nil
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
return err
}
buf := &bytes.Buffer{}
n, err := html.Parse(bytes.NewReader(b))
if err != nil {
buf.Write(b)
} else {
walk(n, makeRelative)
err = html.Render(buf, n)
if err != nil {
return err
}
r.Header.Set("Content-Length", strconv.FormatInt(int64(buf.Len()), 10))
}
r.Body = ioutil.NopCloser(buf)
return nil
}
func makeRelative(n *html.Node) {
switch n.DataAtom {
case atom.A, atom.Link:
for i, attr := range n.Attr {
if attr.Namespace == "" && attr.Key == "href" && strings.HasPrefix(n.Attr[i].Val, "/") {
n.Attr[i].Val = "." + n.Attr[i].Val
}
}
case atom.Script:
for i, attr := range n.Attr {
if attr.Namespace == "" && attr.Key == "src" && strings.HasPrefix(n.Attr[i].Val, "/") {
n.Attr[i].Val = "." + n.Attr[i].Val
}
}
if len(n.Attr) == 0 {
n.FirstChild.Data = strings.Replace(n.FirstChild.Data, `var PATH_PREFIX = "";`, `var PATH_PREFIX = ".";`, 1)
}
}
}
func walk(n *html.Node, f func(*html.Node)) {
f(n)
for c := n.FirstChild; c != nil; c = c.NextSibling {
walk(c, f)
}
}
func (p *prometheus) error(r *http.Request, statusCode int, err error) {
if err != nil {
p.log.Print(err)
}
w := responsewriter.New(r)
http.Error(w, http.StatusText(statusCode), statusCode)
*r = *r.WithContext(context.WithValue(r.Context(), contextKeyResponse, w.Response()))
}

Просмотреть файл

@ -0,0 +1,375 @@
package prometheus
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"bufio"
"context"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httputil"
"strconv"
"strings"
"sync"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
v1 "k8s.io/client-go/tools/clientcmd/api/v1"
"github.com/Azure/ARO-RP/pkg/api"
"github.com/Azure/ARO-RP/pkg/database/cosmosdb"
"github.com/Azure/ARO-RP/pkg/portal/util/responsewriter"
mock_proxy "github.com/Azure/ARO-RP/pkg/util/mocks/proxy"
"github.com/Azure/ARO-RP/pkg/util/portforward"
utiltls "github.com/Azure/ARO-RP/pkg/util/tls"
testdatabase "github.com/Azure/ARO-RP/test/database"
"github.com/Azure/ARO-RP/test/util/listener"
)
type conn struct {
net.Conn
rw *bufio.ReadWriter
}
func (c *conn) Read(b []byte) (int, error) {
return c.rw.Read(b)
}
// fakeServer returns a test listener for an HTTPS server which validates its
// client and forwards a SPDY request to a second HTTP server which echos back
// the request it received.
func fakeServer(cacerts []*x509.Certificate, serverkey *rsa.PrivateKey, servercerts []*x509.Certificate) *listener.Listener {
podl := listener.NewListener()
go func() {
_ = http.Serve(podl, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, _ := httputil.DumpRequest(r, true)
_, _ = w.Write(b)
}))
}()
kubel := listener.NewListener()
pool := x509.NewCertPool()
pool.AddCert(cacerts[0])
go func() {
_ = http.Serve(tls.NewListener(kubel, &tls.Config{
Certificates: []tls.Certificate{
{
Certificate: [][]byte{
servercerts[0].Raw,
},
PrivateKey: serverkey,
},
},
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: pool,
}), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
w.Header().Add(httpstream.HeaderUpgrade, spdy.HeaderSpdy31)
w.WriteHeader(http.StatusSwitchingProtocols)
hijacker := w.(http.Hijacker)
c, buf, err := hijacker.Hijack()
if err != nil {
panic(err)
}
var mu sync.Mutex
var dataStream, errorStream httpstream.Stream
var dataReplySent, errorReplySent <-chan struct{}
var serverconn httpstream.Connection
serverconn, err = spdy.NewServerConnection(&conn{Conn: c, rw: buf}, func(stream httpstream.Stream, replySent <-chan struct{}) error {
mu.Lock()
defer mu.Unlock()
switch stream.Headers().Get(corev1.StreamType) {
case corev1.StreamTypeData:
if dataStream != nil {
return fmt.Errorf("dataStream already set")
}
dataStream = stream
dataReplySent = replySent
case corev1.StreamTypeError:
if errorStream != nil {
return fmt.Errorf("errorStream already set")
}
errorStream = stream
errorReplySent = replySent
}
if dataStream != nil && errorStream != nil {
go func() {
<-dataReplySent
<-errorReplySent
podl.Enqueue(portforward.NewStreamConn(nil, serverconn, dataStream, errorStream))
}()
}
return nil
})
if err != nil {
panic(err)
}
}))
podl.Close()
}()
return kubel
}
func testKubeconfig(cacerts []*x509.Certificate, clientkey *rsa.PrivateKey, clientcerts []*x509.Certificate) ([]byte, error) {
kc := &v1.Config{
Clusters: []v1.NamedCluster{
{
Cluster: v1.Cluster{
Server: "https://kubernetes:6443",
},
},
},
AuthInfos: []v1.NamedAuthInfo{
{},
},
}
var err error
kc.AuthInfos[0].AuthInfo.ClientKeyData, err = utiltls.PrivateKeyAsBytes(clientkey)
if err != nil {
return nil, err
}
kc.AuthInfos[0].AuthInfo.ClientCertificateData, err = utiltls.CertAsBytes(clientcerts[0])
if err != nil {
return nil, err
}
kc.Clusters[0].Cluster.CertificateAuthorityData, err = utiltls.CertAsBytes(cacerts[0])
if err != nil {
return nil, err
}
return json.Marshal(kc)
}
func TestProxy(t *testing.T) {
ctx := context.Background()
resourceID := "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster"
privateEndpointIP := "1.2.3.4"
cakey, cacerts, err := utiltls.GenerateKeyAndCertificate("ca", nil, nil, true, false)
if err != nil {
t.Fatal(err)
}
serverkey, servercerts, err := utiltls.GenerateKeyAndCertificate("kubernetes", cakey, cacerts[0], false, false)
if err != nil {
t.Fatal(err)
}
serviceClientkey, serviceClientcerts, err := utiltls.GenerateKeyAndCertificate("system:aro-service", cakey, cacerts[0], false, true)
if err != nil {
t.Fatal(err)
}
serviceKubeconfig, err := testKubeconfig(cacerts, serviceClientkey, serviceClientcerts)
if err != nil {
t.Fatal(err)
}
l := fakeServer(cacerts, serverkey, servercerts)
defer l.Close()
for _, tt := range []struct {
name string
r func(*http.Request)
fixtureChecker func(*testdatabase.Fixture, *testdatabase.Checker, *cosmosdb.FakeOpenShiftClusterDocumentClient)
mocks func(*mock_proxy.MockDialer)
wantStatusCode int
wantBody string
}{
{
name: "success",
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient) {
openShiftClusterDocument := &api.OpenShiftClusterDocument{
ID: resourceID,
Key: resourceID,
OpenShiftCluster: &api.OpenShiftCluster{
Properties: api.OpenShiftClusterProperties{
NetworkProfile: api.NetworkProfile{
PrivateEndpointIP: privateEndpointIP,
},
AROServiceKubeconfig: api.SecureBytes(serviceKubeconfig),
},
},
}
fixture.AddOpenShiftClusterDocuments(openShiftClusterDocument)
checker.AddOpenShiftClusterDocuments(openShiftClusterDocument)
},
mocks: func(dialer *mock_proxy.MockDialer) {
dialer.EXPECT().DialContext(gomock.Any(), "tcp", privateEndpointIP+":6443").Return(l.DialContext(ctx, "", ""))
},
wantStatusCode: http.StatusOK,
wantBody: "GET /test HTTP/1.1\r\nHost: prometheus-k8s-0:9090\r\nAccept-Encoding: gzip\r\nUser-Agent: Go-http-client/1.1\r\n\r\n",
},
{
name: "bad path",
r: func(r *http.Request) {
r.URL.Path = "/subscriptions/BAD/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster/prometheus/test"
},
wantStatusCode: http.StatusBadRequest,
wantBody: "Bad Request\n",
},
{
name: "sad database",
fixtureChecker: func(fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient) {
openShiftClustersClient.SetError(fmt.Errorf("sad"))
},
wantStatusCode: http.StatusInternalServerError,
wantBody: "Internal Server Error\n",
},
} {
t.Run(tt.name, func(t *testing.T) {
dbOpenShiftClusters, openShiftClustersClient := testdatabase.NewFakeOpenShiftClusters()
fixture := testdatabase.NewFixture().
WithOpenShiftClusters(dbOpenShiftClusters)
checker := testdatabase.NewChecker()
if tt.fixtureChecker != nil {
tt.fixtureChecker(fixture, checker, openShiftClustersClient)
}
err := fixture.Create()
if err != nil {
t.Fatal(err)
}
r, err := http.NewRequest(http.MethodGet,
"https://localhost:8444"+resourceID+"/prometheus/test", nil)
if err != nil {
panic(err)
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
dialer := mock_proxy.NewMockDialer(ctrl)
if tt.mocks != nil {
tt.mocks(dialer)
}
aadAuthenticatedRouter := &mux.Router{}
New(logrus.NewEntry(logrus.StandardLogger()), dbOpenShiftClusters, dialer, aadAuthenticatedRouter)
if tt.r != nil {
tt.r(r)
}
w := responsewriter.New(r)
aadAuthenticatedRouter.ServeHTTP(w, r)
resp := w.Response()
if resp.StatusCode != tt.wantStatusCode {
t.Error(resp.StatusCode)
}
if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" {
t.Error(resp.Header.Get("Content-Type"))
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != tt.wantBody {
t.Errorf("%q", string(b))
}
})
}
}
func TestModifyResponse(t *testing.T) {
for _, tt := range []struct {
name string
body string
wantbody string
}{
{
name: "makes absolute a hrefs relative",
body: `<html><head></head><body><a href="/foo"></a></body></html>`,
wantbody: `<html><head></head><body><a href="./foo"></a></body></html>`,
},
{
name: "makes absolute link hrefs relative",
body: `<html><head></head><body><link href="/foo"/></body></html>`,
wantbody: `<html><head></head><body><link href="./foo"/></body></html>`,
},
{
name: "makes absolute script srcs relative",
body: `<html><head></head><body><script src="/foo"></script></body></html>`,
wantbody: `<html><head></head><body><script src="./foo"></script></body></html>`,
},
{
name: "makes PATH_PREFIX variable relative",
body: `<html><head></head><body><script>var PATH_PREFIX = "";</script></body></html>`,
wantbody: `<html><head></head><body><script>var PATH_PREFIX = ".";</script></body></html>`,
},
} {
t.Run(tt.name, func(t *testing.T) {
r := &http.Response{
Header: http.Header{
"Content-Type": []string{"text/html"},
},
Body: ioutil.NopCloser(strings.NewReader(tt.body)),
}
p := &prometheus{}
err := p.modifyResponse(r)
if err != nil {
t.Fatal(err)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
if string(body) != tt.wantbody {
t.Errorf("%q", string(body))
}
length, err := strconv.Atoi(r.Header.Get("Content-Length"))
if err != nil {
t.Fatal(err)
}
if length != len(body) {
t.Error("length mismatch")
}
})
}
}

285
pkg/portal/security_test.go Normal file
Просмотреть файл

@ -0,0 +1,285 @@
package portal
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"net/http"
"strings"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
"github.com/sirupsen/logrus"
"github.com/Azure/ARO-RP/pkg/portal/middleware"
"github.com/Azure/ARO-RP/pkg/util/deployment"
mock_env "github.com/Azure/ARO-RP/pkg/util/mocks/env"
utiltls "github.com/Azure/ARO-RP/pkg/util/tls"
testdatabase "github.com/Azure/ARO-RP/test/database"
"github.com/Azure/ARO-RP/test/util/listener"
)
var (
elevatedGroupIDs = []string{"00000000-0000-0000-0000-000000000000"}
)
func TestSecurity(t *testing.T) {
ctx := context.Background()
log := logrus.NewEntry(logrus.StandardLogger())
controller := gomock.NewController(t)
defer controller.Finish()
_env := mock_env.NewMockCore(controller)
_env.EXPECT().DeploymentMode().AnyTimes().Return(deployment.Production)
_env.EXPECT().Location().AnyTimes().Return("eastus")
_env.EXPECT().TenantID().AnyTimes().Return("00000000-0000-0000-0000-000000000001")
l := listener.NewListener()
defer l.Close()
sshl := listener.NewListener()
defer sshl.Close()
serverkey, servercerts, err := utiltls.GenerateKeyAndCertificate("server", nil, nil, false, false)
if err != nil {
t.Fatal(err)
}
sshkey, _, err := utiltls.GenerateKeyAndCertificate("ssh", nil, nil, false, false)
if err != nil {
t.Fatal(err)
}
dbOpenShiftClusters, _ := testdatabase.NewFakeOpenShiftClusters()
dbPortal, _ := testdatabase.NewFakePortal()
pool := x509.NewCertPool()
pool.AddCert(servercerts[0])
c := &http.Client{
Transport: &http.Transport{
DialContext: l.DialContext,
TLSClientConfig: &tls.Config{
RootCAs: pool,
},
},
CheckRedirect: func(*http.Request, []*http.Request) error {
return http.ErrUseLastResponse
},
}
p := NewPortal(_env, log, log, l, sshl, nil, "", serverkey, servercerts, "", nil, nil, make([]byte, 32), sshkey, nil, elevatedGroupIDs, dbOpenShiftClusters, dbPortal, nil)
go func() {
err := p.Run(ctx)
if err != nil {
log.Error(err)
}
}()
for _, tt := range []struct {
name string
request func() (*http.Request, error)
checkResponse func(*testing.T, bool, bool, *http.Response)
unauthenticatedWantStatusCode int
authenticatedWantStatusCode int
}{
{
name: "/",
request: func() (*http.Request, error) {
return http.NewRequest(http.MethodGet, "https://server/", nil)
},
},
{
name: "/index.js",
request: func() (*http.Request, error) {
return http.NewRequest(http.MethodGet, "https://server/index.js", nil)
},
},
{
name: "/api/clusters",
request: func() (*http.Request, error) {
return http.NewRequest(http.MethodGet, "https://server/api/clusters", nil)
},
},
{
name: "/api/logout",
request: func() (*http.Request, error) {
return http.NewRequest(http.MethodPost, "https://server/api/logout", nil)
},
authenticatedWantStatusCode: http.StatusSeeOther,
},
{
name: "/callback",
request: func() (*http.Request, error) {
return http.NewRequest(http.MethodGet, "https://server/callback", nil)
},
authenticatedWantStatusCode: http.StatusTemporaryRedirect,
},
{
name: "/healthz/ready",
request: func() (*http.Request, error) {
return http.NewRequest(http.MethodGet, "https://server/healthz/ready", nil)
},
unauthenticatedWantStatusCode: http.StatusOK,
},
{
name: "/kubeconfig/new",
request: func() (*http.Request, error) {
return http.NewRequest(http.MethodPost, "https://server/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourceGroupName/providers/microsoft.redhatopenshift/openshiftclusters/resourceName/kubeconfig/new", nil)
},
},
{
name: "/prometheus",
request: func() (*http.Request, error) {
return http.NewRequest(http.MethodPost, "https://server/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourceGroupName/providers/microsoft.redhatopenshift/openshiftclusters/resourceName/prometheus", nil)
},
authenticatedWantStatusCode: http.StatusTemporaryRedirect,
},
{
name: "/ssh/new",
request: func() (*http.Request, error) {
req, err := http.NewRequest(http.MethodPost, "https://server/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/resourceGroupName/providers/microsoft.redhatopenshift/openshiftclusters/resourceName/ssh/new", strings.NewReader("{}"))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
return req, nil
},
checkResponse: func(t *testing.T, authenticated, elevated bool, resp *http.Response) {
if authenticated && !elevated {
var e struct {
Error string
}
err := json.NewDecoder(resp.Body).Decode(&e)
if err != nil {
t.Fatal(err)
}
if e.Error != "Elevated access is required." {
t.Error(e.Error)
}
}
},
},
{
name: "/doesnotexist",
request: func() (*http.Request, error) {
return http.NewRequest(http.MethodGet, "https://server/doesnotexist", nil)
},
unauthenticatedWantStatusCode: http.StatusNotFound,
authenticatedWantStatusCode: http.StatusNotFound,
},
} {
for _, tt2 := range []struct {
name string
authenticated bool
elevated bool
wantStatusCode int
}{
{
name: "unauthenticated",
wantStatusCode: tt.unauthenticatedWantStatusCode,
},
{
name: "authenticated",
authenticated: true,
wantStatusCode: tt.authenticatedWantStatusCode,
},
{
name: "elevated",
authenticated: true,
elevated: true,
wantStatusCode: tt.authenticatedWantStatusCode,
},
} {
t.Run(tt2.name+tt.name, func(t *testing.T) {
req, err := tt.request()
if err != nil {
t.Fatal(err)
}
err = addCSRF(req)
if err != nil {
t.Fatal(err)
}
if tt2.authenticated {
var groups []string
if tt2.elevated {
groups = elevatedGroupIDs
}
err = addAuth(req, groups)
if err != nil {
t.Fatal(err)
}
}
resp, err := c.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if tt2.wantStatusCode == 0 {
if tt2.authenticated {
tt2.wantStatusCode = http.StatusOK
} else {
tt2.wantStatusCode = http.StatusTemporaryRedirect
}
}
if resp.StatusCode != tt2.wantStatusCode {
t.Error(resp.StatusCode)
}
if tt.checkResponse != nil {
tt.checkResponse(t, tt2.authenticated, tt2.elevated, resp)
}
})
}
}
}
func addCSRF(req *http.Request) error {
if req.Method != http.MethodPost {
return nil
}
req.Header.Set("X-CSRF-Token", base64.StdEncoding.EncodeToString(make([]byte, 64)))
sc := securecookie.New(make([]byte, 32), nil)
sc.SetSerializer(securecookie.JSONEncoder{})
cookie, err := sc.Encode("_gorilla_csrf", make([]byte, 32))
if err != nil {
return err
}
req.Header.Add("Cookie", "_gorilla_csrf="+cookie)
return nil
}
func addAuth(req *http.Request, groups []string) error {
store := sessions.NewCookieStore(make([]byte, 32))
cookie, err := securecookie.EncodeMulti(middleware.SessionName, map[interface{}]interface{}{
middleware.SessionKeyUsername: "username",
middleware.SessionKeyGroups: groups,
middleware.SessionKeyExpires: time.Now().Add(time.Hour),
}, store.Codecs...)
if err != nil {
return err
}
req.Header.Add("Cookie", middleware.SessionName+"="+cookie)
return nil
}

334
pkg/portal/ssh/proxy.go Normal file
Просмотреть файл

@ -0,0 +1,334 @@
package ssh
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"context"
"crypto/x509"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
uuid "github.com/satori/go.uuid"
"github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"github.com/Azure/ARO-RP/pkg/api"
utillog "github.com/Azure/ARO-RP/pkg/util/log"
"github.com/Azure/ARO-RP/pkg/util/recover"
)
const (
sshTimeout = time.Hour
)
func (s *ssh) Run() error {
go func() {
defer recover.Panic(s.log)
for {
c, err := s.l.Accept()
if err != nil {
return
}
go func() {
defer recover.Panic(s.log)
err := s.newConn(context.Background(), c)
if err != nil {
s.log.Warn(err)
}
}()
}
}()
return nil
}
func (s *ssh) newConn(ctx context.Context, c1 net.Conn) error {
defer c1.Close()
config := &cryptossh.ServerConfig{}
*config = *s.baseServerConfig
var portalDoc *api.PortalDocument
var connmetadata cryptossh.ConnMetadata
config.PasswordCallback = func(_connmetadata cryptossh.ConnMetadata, pw []byte) (*cryptossh.Permissions, error) {
connmetadata = _connmetadata
password, err := uuid.FromString(string(pw))
if err != nil {
return nil, fmt.Errorf("invalid username") // don't echo password attempt to logs
}
portalDoc, err = s.dbPortal.Get(ctx, password.String())
if err != nil {
return nil, fmt.Errorf("invalid username") // don't echo password attempt to logs
}
if portalDoc.Portal.SSH == nil ||
connmetadata.User() != strings.SplitN(portalDoc.Portal.Username, "@", 2)[0] {
return nil, fmt.Errorf("invalid username")
}
return nil, s.dbPortal.Delete(ctx, portalDoc)
}
conn1, newchannels1, requests1, err := cryptossh.NewServerConn(c1, config)
if err != nil {
var username string
if connmetadata != nil { // after a password attempt
username = connmetadata.User()
}
s.baseAccessLog.WithFields(logrus.Fields{
"remote_addr": c1.RemoteAddr().String(),
"username": username,
}).Warn("authentication failed")
return err
}
accessLog := utillog.EnrichWithPath(s.baseAccessLog, portalDoc.Portal.ID)
accessLog = accessLog.WithFields(logrus.Fields{
"hostname": fmt.Sprintf("master-%d", portalDoc.Portal.SSH.Master),
"remote_addr": c1.RemoteAddr().String(),
"username": portalDoc.Portal.Username,
})
accessLog.Print("authentication succeeded")
openShiftDoc, err := s.dbOpenShiftClusters.Get(ctx, strings.ToLower(portalDoc.Portal.ID))
if err != nil {
return err
}
address := fmt.Sprintf("%s:%d", openShiftDoc.OpenShiftCluster.Properties.NetworkProfile.PrivateEndpointIP, 2200+portalDoc.Portal.SSH.Master)
c2, err := s.dialer.DialContext(ctx, "tcp", address)
if err != nil {
return err
}
defer c2.Close()
key, err := x509.ParsePKCS1PrivateKey(openShiftDoc.OpenShiftCluster.Properties.SSHKey)
if err != nil {
return err
}
signer, err := cryptossh.NewSignerFromSigner(key)
if err != nil {
return err
}
conn2, newchannels2, requests2, err := cryptossh.NewClientConn(c2, "", &cryptossh.ClientConfig{
User: "core",
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(signer),
},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
})
if err != nil {
return err
}
t := time.Now()
accessLog.Print("connected")
defer func() {
accessLog.WithFields(logrus.Fields{
"duration": time.Since(t).Seconds(),
}).Print("disconnected")
}()
keyring := agent.NewKeyring()
err = keyring.Add(agent.AddedKey{PrivateKey: key})
if err != nil {
return err
}
return s.proxyConn(accessLog, keyring, conn1, conn2, newchannels1, newchannels2, requests1, requests2)
}
func (s *ssh) proxyConn(accessLog *logrus.Entry, keyring agent.Agent, conn1, conn2 cryptossh.Conn, newchannels1, newchannels2 <-chan cryptossh.NewChannel, requests1, requests2 <-chan *cryptossh.Request) error {
timer := time.NewTimer(sshTimeout)
defer timer.Stop()
var sessionOpened bool
for {
select {
case <-timer.C:
return nil
case nc := <-newchannels1:
if nc == nil {
return nil
}
// on the first c->s session, advertise agent availability
var firstSession bool
if !sessionOpened && nc.ChannelType() == "session" {
firstSession = true
sessionOpened = true
}
go func() {
_ = s.newChannel(accessLog, nc, conn1, conn2, firstSession)
}()
case nc := <-newchannels2:
if nc == nil {
return nil
}
// hijack and handle incoming s->c agent requests
if nc.ChannelType() == "auth-agent@openssh.com" {
go func() {
_ = s.handleAgent(accessLog, nc, keyring)
}()
} else {
go func() {
_ = s.newChannel(accessLog, nc, conn2, conn1, false)
}()
}
case request := <-requests1:
if request == nil {
return nil
}
_ = s.proxyGlobalRequest(request, conn2)
case request := <-requests2:
if request == nil {
return nil
}
_ = s.proxyGlobalRequest(request, conn1)
}
}
}
func (s *ssh) handleAgent(accessLog *logrus.Entry, nc cryptossh.NewChannel, keyring agent.Agent) error {
ch, rs, err := nc.Accept()
if err != nil {
return err
}
defer ch.Close()
channelLog := accessLog.WithFields(logrus.Fields{
"channel": nc.ChannelType(),
})
channelLog.Printf("opened")
defer channelLog.Printf("closed")
go cryptossh.DiscardRequests(rs)
return agent.ServeAgent(keyring, ch)
}
func (s *ssh) newChannel(accessLog *logrus.Entry, nc cryptossh.NewChannel, conn1, conn2 cryptossh.Conn, firstSession bool) error {
defer recover.Panic(s.log)
ch2, rs2, err := conn2.OpenChannel(nc.ChannelType(), nc.ExtraData())
if err, ok := err.(*cryptossh.OpenChannelError); ok {
return nc.Reject(err.Reason, err.Message)
} else if err != nil {
return err
}
ch1, rs1, err := nc.Accept()
if err != nil {
return err
}
channelLog := accessLog.WithFields(logrus.Fields{
"channel": nc.ChannelType(),
})
channelLog.Printf("opened")
defer channelLog.Printf("closed")
if firstSession {
_, err = ch2.SendRequest("auth-agent-req@openssh.com", true, nil)
if err != nil {
return err
}
}
return s.proxyChannel(ch1, ch2, rs1, rs2)
}
func (s *ssh) proxyGlobalRequest(r *cryptossh.Request, c cryptossh.Conn) error {
ok, payload, err := c.SendRequest(r.Type, r.WantReply, r.Payload)
if err != nil {
return err
}
return r.Reply(ok, payload)
}
func (s *ssh) proxyRequest(r *cryptossh.Request, ch cryptossh.Channel) error {
ok, err := ch.SendRequest(r.Type, r.WantReply, r.Payload)
if err != nil {
return err
}
return r.Reply(ok, nil)
}
func (s *ssh) proxyChannel(ch1, ch2 cryptossh.Channel, rs1, rs2 <-chan *cryptossh.Request) error {
var wg sync.WaitGroup
wg.Add(4)
go func() {
defer recover.Panic(s.log)
defer wg.Done()
_, _ = io.Copy(ch1, ch2)
_ = ch1.CloseWrite()
}()
go func() {
defer recover.Panic(s.log)
defer wg.Done()
_, _ = io.Copy(ch2, ch1)
_ = ch2.CloseWrite()
}()
go func() {
defer recover.Panic(s.log)
defer wg.Done()
for r := range rs1 {
err := s.proxyRequest(r, ch2)
if err != nil {
break
}
}
_ = ch2.Close()
}()
go func() {
defer recover.Panic(s.log)
defer wg.Done()
for r := range rs2 {
err := s.proxyRequest(r, ch1)
if err != nil {
break
}
}
_ = ch1.Close()
}()
wg.Wait()
return nil
}

Просмотреть файл

@ -0,0 +1,443 @@
package ssh
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"bytes"
"context"
"crypto/rsa"
"crypto/x509"
"fmt"
"net"
"strings"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/onsi/gomega"
"github.com/onsi/gomega/types"
"github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
"github.com/Azure/ARO-RP/pkg/api"
"github.com/Azure/ARO-RP/pkg/database/cosmosdb"
mock_proxy "github.com/Azure/ARO-RP/pkg/util/mocks/proxy"
"github.com/Azure/ARO-RP/pkg/util/tls"
testdatabase "github.com/Azure/ARO-RP/test/database"
"github.com/Azure/ARO-RP/test/util/bufferedpipe"
"github.com/Azure/ARO-RP/test/util/listener"
testlog "github.com/Azure/ARO-RP/test/util/log"
)
// fakeClient runs a fake client on the given connection. It validates the
// server key, authenticates, writes a ping request, reads a pong reply, then
// closes the connection
func fakeClient(c net.Conn, serverKey *rsa.PublicKey, user string, password string) error {
publicKey, err := cryptossh.NewPublicKey(serverKey)
if err != nil {
return err
}
conn, _, _, err := cryptossh.NewClientConn(c, "", &cryptossh.ClientConfig{
HostKeyCallback: cryptossh.FixedHostKey(publicKey),
User: user,
Auth: []cryptossh.AuthMethod{
cryptossh.Password(password),
},
})
if err != nil {
return err
}
_, reply, err := conn.SendRequest("ping", true, []byte("ping"))
if err != nil {
return err
}
if string(reply) != "pong" {
return fmt.Errorf("invalid reply %q", string(reply))
}
return conn.Close()
}
// fakeServer returns a test listener for an SSH server which validates the
// client key, reads ping request(s) and writes pong replies
func fakeServer(clientKey *rsa.PublicKey) (*listener.Listener, error) {
l := listener.NewListener()
clientPublicKey, err := cryptossh.NewPublicKey(clientKey)
if err != nil {
return nil, err
}
config := &cryptossh.ServerConfig{
PublicKeyCallback: func(conn cryptossh.ConnMetadata, key cryptossh.PublicKey) (*cryptossh.Permissions, error) {
if conn.User() != "core" {
return nil, fmt.Errorf("invalid user")
}
if !bytes.Equal(key.Marshal(), clientPublicKey.Marshal()) {
return nil, fmt.Errorf("invalid key")
}
return nil, nil
},
}
key, _, err := tls.GenerateKeyAndCertificate("server", nil, nil, false, false)
if err != nil {
return nil, err
}
signer, err := cryptossh.NewSignerFromSigner(key)
if err != nil {
return nil, err
}
config.AddHostKey(signer)
go func() {
for {
c, err := l.Accept()
if err != nil {
return
}
go func() {
conn, _, requests, err := cryptossh.NewServerConn(c, config)
if err != nil {
return
}
go func() {
for request := range requests {
if request.Type == "ping" && request.WantReply {
err := request.Reply(true, []byte("pong"))
if err != nil {
break
}
} else {
err := request.Reply(false, nil)
if err != nil {
break
}
}
}
}()
_ = conn.Wait()
}()
}
}()
return l, nil
}
func TestProxy(t *testing.T) {
ctx := context.Background()
username := "test"
password := "00000000-0000-0000-0000-000000000000"
subscriptionID := "10000000-0000-0000-0000-000000000000"
resourceGroup := "rg"
resourceName := "cluster"
resourceID := "/subscriptions/" + subscriptionID + "/resourcegroups/" + resourceGroup + "/providers/microsoft.redhatopenshift/openshiftclusters/" + resourceName
privateEndpointIP := "1.2.3.4"
hostKey, _, err := tls.GenerateKeyAndCertificate("proxy", nil, nil, false, false)
if err != nil {
t.Fatal(err)
}
clusterKey, _, err := tls.GenerateKeyAndCertificate("cluster", nil, nil, false, false)
if err != nil {
t.Fatal(err)
}
l, err := fakeServer(&clusterKey.PublicKey)
if err != nil {
t.Fatal(err)
}
defer l.Close()
goodOpenShiftClusterDocument := func() *api.OpenShiftClusterDocument {
return &api.OpenShiftClusterDocument{
ID: resourceID,
Key: resourceID,
OpenShiftCluster: &api.OpenShiftCluster{
Properties: api.OpenShiftClusterProperties{
NetworkProfile: api.NetworkProfile{
PrivateEndpointIP: privateEndpointIP,
},
SSHKey: api.SecureBytes(x509.MarshalPKCS1PrivateKey(clusterKey)),
},
},
}
}
goodPortalDocument := func(id string) *api.PortalDocument {
return &api.PortalDocument{
ID: id,
Portal: &api.Portal{
ID: resourceID,
Username: username,
SSH: &api.SSH{
Master: 1,
},
},
}
}
type test struct {
name string
username string
password string
fixtureChecker func(*test, *testdatabase.Fixture, *testdatabase.Checker, *cosmosdb.FakeOpenShiftClusterDocumentClient, *cosmosdb.FakePortalDocumentClient)
mocks func(*mock_proxy.MockDialer)
wantErrPrefix string
wantLogs []map[string]types.GomegaMatcher
}
for _, tt := range []*test{
{
name: "good",
username: username,
password: password,
fixtureChecker: func(tt *test, fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
portalDocument := goodPortalDocument(tt.password)
fixture.AddPortalDocuments(portalDocument)
openShiftClusterDocument := goodOpenShiftClusterDocument()
fixture.AddOpenShiftClusterDocuments(openShiftClusterDocument)
checker.AddOpenShiftClusterDocuments(openShiftClusterDocument)
},
mocks: func(dialer *mock_proxy.MockDialer) {
dialer.EXPECT().DialContext(gomock.Any(), "tcp", privateEndpointIP+":2201").Return(l.DialContext(ctx, "", ""))
},
wantLogs: []map[string]types.GomegaMatcher{
{
"level": gomega.Equal(logrus.InfoLevel),
"msg": gomega.Equal("authentication succeeded"),
"remote_addr": gomega.Not(gomega.BeEmpty()),
"username": gomega.Equal(username),
},
{
"level": gomega.Equal(logrus.InfoLevel),
"msg": gomega.Equal("connected"),
"hostname": gomega.Equal("master-1"),
"resource_group": gomega.Equal(resourceGroup),
"resource_id": gomega.Equal(resourceID),
"resource_name": gomega.Equal(resourceName),
"subscription_id": gomega.Equal(subscriptionID),
"username": gomega.Equal(username),
},
{
"level": gomega.Equal(logrus.InfoLevel),
"msg": gomega.Equal("disconnected"),
"duration": gomega.BeNumerically(">", 0),
"hostname": gomega.Equal("master-1"),
"resource_group": gomega.Equal(resourceGroup),
"resource_id": gomega.Equal(resourceID),
"resource_name": gomega.Equal(resourceName),
"subscription_id": gomega.Equal(subscriptionID),
"username": gomega.Equal(username),
},
},
},
{
name: "bad username",
username: "bad",
password: password,
fixtureChecker: func(tt *test, fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
portalDocument := goodPortalDocument(tt.password)
fixture.AddPortalDocuments(portalDocument)
checker.AddPortalDocuments(portalDocument)
},
wantErrPrefix: "ssh: handshake failed",
wantLogs: []map[string]types.GomegaMatcher{
{
"level": gomega.Equal(logrus.WarnLevel),
"msg": gomega.Equal("authentication failed"),
"remote_addr": gomega.Not(gomega.BeEmpty()),
"username": gomega.Equal("bad"),
},
},
},
{
name: "bad password, not uuid",
username: username,
password: "bad",
wantErrPrefix: "ssh: handshake failed",
wantLogs: []map[string]types.GomegaMatcher{
{
"level": gomega.Equal(logrus.WarnLevel),
"msg": gomega.Equal("authentication failed"),
"remote_addr": gomega.Not(gomega.BeEmpty()),
"username": gomega.Equal(username),
},
},
},
{
name: "bad password",
username: username,
password: password,
wantErrPrefix: "ssh: handshake failed",
wantLogs: []map[string]types.GomegaMatcher{
{
"level": gomega.Equal(logrus.WarnLevel),
"msg": gomega.Equal("authentication failed"),
"remote_addr": gomega.Not(gomega.BeEmpty()),
"username": gomega.Equal(username),
},
},
},
{
name: "not ssh record",
username: username,
password: password,
fixtureChecker: func(tt *test, fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
portalDocument := goodPortalDocument(tt.password)
portalDocument.Portal.SSH = nil
fixture.AddPortalDocuments(portalDocument)
checker.AddPortalDocuments(portalDocument)
},
wantErrPrefix: "ssh: handshake failed",
wantLogs: []map[string]types.GomegaMatcher{
{
"level": gomega.Equal(logrus.WarnLevel),
"msg": gomega.Equal("authentication failed"),
"remote_addr": gomega.Not(gomega.BeEmpty()),
"username": gomega.Equal(username),
},
},
},
{
name: "sad openshiftClusters database",
username: username,
password: password,
fixtureChecker: func(tt *test, fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
portalDocument := goodPortalDocument(tt.password)
fixture.AddPortalDocuments(portalDocument)
openShiftClustersClient.SetError(fmt.Errorf("sad"))
},
wantErrPrefix: "EOF",
wantLogs: []map[string]types.GomegaMatcher{
{
"level": gomega.Equal(logrus.InfoLevel),
"msg": gomega.Equal("authentication succeeded"),
"remote_addr": gomega.Not(gomega.BeEmpty()),
"username": gomega.Equal(username),
},
},
},
{
name: "sad portal database",
username: username,
password: password,
fixtureChecker: func(tt *test, fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
portalClient.SetError(fmt.Errorf("sad"))
},
wantErrPrefix: "ssh: handshake failed",
wantLogs: []map[string]types.GomegaMatcher{
{
"level": gomega.Equal(logrus.WarnLevel),
"msg": gomega.Equal("authentication failed"),
"remote_addr": gomega.Not(gomega.BeEmpty()),
"username": gomega.Equal(username),
},
},
},
{
name: "sad dialer",
username: username,
password: password,
fixtureChecker: func(tt *test, fixture *testdatabase.Fixture, checker *testdatabase.Checker, openShiftClustersClient *cosmosdb.FakeOpenShiftClusterDocumentClient, portalClient *cosmosdb.FakePortalDocumentClient) {
portalDocument := goodPortalDocument(tt.password)
fixture.AddPortalDocuments(portalDocument)
openShiftClusterDocument := goodOpenShiftClusterDocument()
fixture.AddOpenShiftClusterDocuments(openShiftClusterDocument)
checker.AddOpenShiftClusterDocuments(openShiftClusterDocument)
},
mocks: func(dialer *mock_proxy.MockDialer) {
dialer.EXPECT().DialContext(gomock.Any(), "tcp", privateEndpointIP+":2201").Return(nil, fmt.Errorf("sad"))
},
wantErrPrefix: "EOF",
wantLogs: []map[string]types.GomegaMatcher{
{
"level": gomega.Equal(logrus.InfoLevel),
"msg": gomega.Equal("authentication succeeded"),
"remote_addr": gomega.Not(gomega.BeEmpty()),
"username": gomega.Equal(username),
},
},
},
} {
t.Run(tt.name, func(t *testing.T) {
dbPortal, portalClient := testdatabase.NewFakePortal()
dbOpenShiftClusters, openShiftClustersClient := testdatabase.NewFakeOpenShiftClusters()
fixture := testdatabase.NewFixture().
WithOpenShiftClusters(dbOpenShiftClusters).
WithPortal(dbPortal)
checker := testdatabase.NewChecker()
if tt.fixtureChecker != nil {
tt.fixtureChecker(tt, fixture, checker, openShiftClustersClient, portalClient)
}
err := fixture.Create()
if err != nil {
t.Fatal(err)
}
client, client1 := bufferedpipe.New()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
dialer := mock_proxy.NewMockDialer(ctrl)
if tt.mocks != nil {
tt.mocks(dialer)
}
hook, log := testlog.New()
s, err := New(nil, nil, log, nil, hostKey, nil, dbOpenShiftClusters, dbPortal, dialer, &mux.Router{})
if err != nil {
t.Fatal(err)
}
done := make(chan struct{})
go func() {
_ = s.newConn(context.Background(), client1)
close(done)
}()
err = fakeClient(client, &hostKey.PublicKey, tt.username, tt.password)
if err != nil && !strings.HasPrefix(err.Error(), tt.wantErrPrefix) ||
err == nil && tt.wantErrPrefix != "" {
t.Error(err)
}
<-done
openShiftClustersClient.SetError(nil)
portalClient.SetError(nil)
for _, err = range checker.CheckOpenShiftClusters(openShiftClustersClient) {
t.Error(err)
}
for _, err = range checker.CheckPortals(portalClient) {
t.Error(err)
}
err = testlog.AssertLoggingOutput(hook, tt.wantLogs)
if err != nil {
t.Error(err)
}
})
}
}

193
pkg/portal/ssh/ssh.go Normal file
Просмотреть файл

@ -0,0 +1,193 @@
package ssh
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"crypto/rsa"
"encoding/json"
"fmt"
"mime"
"net"
"net/http"
"strings"
"time"
"github.com/gorilla/mux"
uuid "github.com/satori/go.uuid"
"github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
"github.com/Azure/ARO-RP/pkg/api"
"github.com/Azure/ARO-RP/pkg/api/validate"
"github.com/Azure/ARO-RP/pkg/database"
"github.com/Azure/ARO-RP/pkg/env"
"github.com/Azure/ARO-RP/pkg/portal/middleware"
"github.com/Azure/ARO-RP/pkg/proxy"
"github.com/Azure/ARO-RP/pkg/util/deployment"
)
const (
sshNewTimeout = time.Minute
)
type ssh struct {
env env.Core
log *logrus.Entry
baseAccessLog *logrus.Entry
l net.Listener
elevatedGroupIDs []string
dbOpenShiftClusters database.OpenShiftClusters
dbPortal database.Portal
dialer proxy.Dialer
baseServerConfig *cryptossh.ServerConfig
newPassword func() string
}
func New(env env.Core,
log *logrus.Entry,
baseAccessLog *logrus.Entry,
l net.Listener,
hostKey *rsa.PrivateKey,
elevatedGroupIDs []string,
dbOpenShiftClusters database.OpenShiftClusters,
dbPortal database.Portal,
dialer proxy.Dialer,
aadAuthenticatedRouter *mux.Router) (*ssh, error) {
s := &ssh{
env: env,
log: log,
baseAccessLog: baseAccessLog,
l: l,
elevatedGroupIDs: elevatedGroupIDs,
dbOpenShiftClusters: dbOpenShiftClusters,
dbPortal: dbPortal,
dialer: dialer,
baseServerConfig: &cryptossh.ServerConfig{},
newPassword: func() string { return uuid.NewV4().String() },
}
signer, err := cryptossh.NewSignerFromSigner(hostKey)
if err != nil {
return nil, err
}
s.baseServerConfig.AddHostKey(signer)
aadAuthenticatedRouter.NewRoute().Methods(http.MethodPost).Path("/subscriptions/{subscriptionId}/resourcegroups/{resourceGroupName}/providers/microsoft.redhatopenshift/openshiftclusters/{resourceName}/ssh/new").HandlerFunc(s.new)
return s, nil
}
type request struct {
Master int `json:"master,omitempty"`
}
type response struct {
Command string `json:"command,omitempty"`
Password string `json:"password,omitempty"`
Error string `json:"error,omitempty"`
}
func (s *ssh) new(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
parts := strings.Split(r.URL.Path, "/")
if len(parts) < 9 {
http.Error(w, "invalid resourceId", http.StatusBadRequest)
return
}
resourceID := strings.Join(parts[:9], "/")
if !validate.RxClusterID.MatchString(resourceID) {
http.Error(w, fmt.Sprintf("invalid resourceId %q", resourceID), http.StatusBadRequest)
return
}
mediatype, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
if mediatype != "application/json" {
http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType)
return
}
var req *request
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil || req.Master < 0 || req.Master > 2 {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
elevated := middleware.GroupsIntersect(s.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))
if !elevated {
s.sendResponse(w, &response{
Error: "Elevated access is required.",
})
return
}
username := r.Context().Value(middleware.ContextKeyUsername).(string)
username = strings.SplitN(username, "@", 2)[0]
password := s.newPassword()
portalDoc := &api.PortalDocument{
ID: password,
TTL: int(sshNewTimeout / time.Second),
Portal: &api.Portal{
Username: ctx.Value(middleware.ContextKeyUsername).(string),
ID: resourceID,
SSH: &api.SSH{
Master: req.Master,
},
},
}
_, err = s.dbPortal.Create(ctx, portalDoc)
if err != nil {
s.internalServerError(w, err)
return
}
host := r.Host
if strings.ContainsRune(r.Host, ':') {
host, _, err = net.SplitHostPort(r.Host)
if err != nil {
s.internalServerError(w, err)
return
}
}
port := ""
if s.env.DeploymentMode() == deployment.Development {
port = "-p 2222 "
}
s.sendResponse(w, &response{
Command: fmt.Sprintf("ssh %s%s@%s", port, username, host),
Password: password,
})
}
func (s *ssh) sendResponse(w http.ResponseWriter, resp *response) {
b, err := json.MarshalIndent(resp, "", " ")
if err != nil {
s.internalServerError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(b)
}
func (s *ssh) internalServerError(w http.ResponseWriter, err error) {
s.log.Warn(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}

190
pkg/portal/ssh/ssh_test.go Normal file
Просмотреть файл

@ -0,0 +1,190 @@
package ssh
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
"github.com/Azure/ARO-RP/pkg/api"
"github.com/Azure/ARO-RP/pkg/database/cosmosdb"
"github.com/Azure/ARO-RP/pkg/portal/middleware"
"github.com/Azure/ARO-RP/pkg/portal/util/responsewriter"
"github.com/Azure/ARO-RP/pkg/util/deployment"
mock_env "github.com/Azure/ARO-RP/pkg/util/mocks/env"
"github.com/Azure/ARO-RP/pkg/util/tls"
testdatabase "github.com/Azure/ARO-RP/test/database"
)
func TestNew(t *testing.T) {
resourceID := "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster"
elevatedGroupIDs := []string{"10000000-0000-0000-0000-000000000000"}
username := "username"
password := "password"
master := 0
hostKey, _, err := tls.GenerateKeyAndCertificate("proxy", nil, nil, false, false)
if err != nil {
t.Fatal(err)
}
for _, tt := range []struct {
name string
deploymentMode deployment.Mode
r func(*http.Request)
checker func(*testdatabase.Checker, *cosmosdb.FakePortalDocumentClient)
wantStatusCode int
wantBody string
}{
{
name: "success",
checker: func(checker *testdatabase.Checker, portalClient *cosmosdb.FakePortalDocumentClient) {
checker.AddPortalDocuments(&api.PortalDocument{
ID: password,
TTL: 60,
Portal: &api.Portal{
Username: username,
ID: resourceID,
SSH: &api.SSH{
Master: master,
},
},
})
},
wantStatusCode: http.StatusOK,
wantBody: "{\n \"command\": \"ssh username@localhost\",\n \"password\": \"password\"\n}",
},
{
name: "bad path",
r: func(r *http.Request) {
r.URL.Path = "/subscriptions/BAD/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster/ssh/new"
},
wantStatusCode: http.StatusBadRequest,
wantBody: "invalid resourceId \"/subscriptions/BAD/resourcegroups/rg/providers/microsoft.redhatopenshift/openshiftclusters/cluster\"\n",
},
{
name: "bad content type",
r: func(r *http.Request) {
r.Header.Set("Content-Type", "bad")
},
wantStatusCode: http.StatusUnsupportedMediaType,
wantBody: "Unsupported Media Type\n",
},
{
name: "empty request",
r: func(r *http.Request) {
r.Body = ioutil.NopCloser(bytes.NewReader(nil))
},
wantStatusCode: http.StatusBadRequest,
wantBody: "Bad Request\n",
},
{
name: "junk request",
r: func(r *http.Request) {
r.Body = ioutil.NopCloser(strings.NewReader("{{"))
},
wantStatusCode: http.StatusBadRequest,
wantBody: "Bad Request\n",
},
{
name: "not elevated",
r: func(r *http.Request) {
*r = *r.WithContext(context.WithValue(r.Context(), middleware.ContextKeyGroups, []string{}))
},
wantStatusCode: http.StatusOK,
wantBody: "{\n \"error\": \"Elevated access is required.\"\n}",
},
{
name: "sad database",
checker: func(checker *testdatabase.Checker, portalClient *cosmosdb.FakePortalDocumentClient) {
portalClient.SetError(fmt.Errorf("sad"))
},
wantStatusCode: http.StatusInternalServerError,
wantBody: "Internal Server Error\n",
},
} {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
dbPortal, portalClient := testdatabase.NewFakePortal()
checker := testdatabase.NewChecker()
if tt.checker != nil {
tt.checker(checker, portalClient)
}
ctx = context.WithValue(ctx, middleware.ContextKeyUsername, username)
ctx = context.WithValue(ctx, middleware.ContextKeyGroups, elevatedGroupIDs)
r, err := http.NewRequestWithContext(ctx, http.MethodPost,
"https://localhost:8444"+resourceID+"/ssh/new", strings.NewReader(fmt.Sprintf(`{"master":%d}`, master)))
if err != nil {
panic(err)
}
r.Header.Set("Content-Type", "application/json")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
env := mock_env.NewMockCore(ctrl)
env.EXPECT().DeploymentMode().AnyTimes().Return(tt.deploymentMode)
aadAuthenticatedRouter := &mux.Router{}
s, err := New(env, logrus.NewEntry(logrus.StandardLogger()), nil, nil, hostKey, elevatedGroupIDs, nil, dbPortal, nil, aadAuthenticatedRouter)
if err != nil {
t.Fatal(err)
}
s.newPassword = func() string { return password }
if tt.r != nil {
tt.r(r)
}
w := responsewriter.New(r)
aadAuthenticatedRouter.ServeHTTP(w, r)
portalClient.SetError(nil)
for _, err = range checker.CheckPortals(portalClient) {
t.Error(err)
}
resp := w.Response()
if resp.StatusCode != tt.wantStatusCode {
t.Error(resp.StatusCode)
}
wantContentType := "application/json"
if resp.StatusCode != http.StatusOK {
wantContentType = "text/plain; charset=utf-8"
}
if resp.Header.Get("Content-Type") != wantContentType {
t.Error(resp.Header.Get("Content-Type"))
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != tt.wantBody {
t.Errorf("%q", string(b))
}
})
}
}

Просмотреть файл

@ -0,0 +1,76 @@
package clientcache
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"net/http"
"sync"
"time"
)
// ClientCache is a cache for *http.Clients. It allows us to reuse clients and
// connections across multiple incoming calls, saving us TCP, TLS and proxy
// initialisations.
type ClientCache interface {
Get(interface{}) *http.Client
Put(interface{}, *http.Client)
}
type clientCache struct {
mu sync.Mutex
now func() time.Time
ttl time.Duration
m map[interface{}]*v
}
type v struct {
expires time.Time
cli *http.Client
}
// New returns a new ClientCache
func New(ttl time.Duration) ClientCache {
return &clientCache{
now: time.Now,
ttl: ttl,
m: map[interface{}]*v{},
}
}
// call holding c.mu
func (c *clientCache) expire() {
now := c.now()
for k, v := range c.m {
if now.After(v.expires) {
v.cli.CloseIdleConnections()
delete(c.m, k)
}
}
}
func (c *clientCache) Get(k interface{}) (cli *http.Client) {
c.mu.Lock()
defer c.mu.Unlock()
if v := c.m[k]; v != nil {
v.expires = c.now().Add(c.ttl)
cli = v.cli
}
c.expire()
return
}
func (c *clientCache) Put(k interface{}, cli *http.Client) {
c.mu.Lock()
defer c.mu.Unlock()
c.m[k] = &v{
expires: c.now().Add(c.ttl),
cli: cli,
}
c.expire()
}

Просмотреть файл

@ -0,0 +1,54 @@
package clientcache
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"net/http"
"testing"
"time"
)
func TestClientCache(t *testing.T) {
var now time.Time
cli1 := &http.Client{}
cli2 := &http.Client{}
c := New(1)
c.(*clientCache).now = func() time.Time { return now }
// t = 0: put(1), get(1)
c.Put(1, cli1)
if c.Get(1) != cli1 {
t.Error(c.Get(1), cli1)
}
now = now.Add(2)
// t = 2: put(1), get(1) (cli1's ttl should be reset before expiring)
if c.Get(1) != cli1 {
t.Error(c.Get(1), cli1)
}
now = now.Add(2)
// t = 4: put(2), get(2) (cli1 should be expired)
c.Put(2, cli2)
if c.Get(2) != cli2 {
t.Error(c.Get(2), cli2)
}
if c.Get(1) != nil {
t.Error(c.Get(1))
}
now = now.Add(2)
// t = 6: cli2 should be expired
c.(*clientCache).expire()
if c.Get(2) != nil {
t.Error(c.Get(2))
}
}

Просмотреть файл

@ -0,0 +1,51 @@
package responsewriter
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"bytes"
"io/ioutil"
"net/http"
)
// ResponseWriter represents a ResponseWriter
type ResponseWriter interface {
http.ResponseWriter
Response() *http.Response
}
type responseWriter struct {
bytes.Buffer
r *http.Request
h http.Header
statusCode int
}
// New returns an http.ResponseWriter on which you can later call Response() to
// generate an *http.Response.
func New(r *http.Request) ResponseWriter {
return &responseWriter{
r: r,
h: http.Header{},
statusCode: http.StatusOK,
}
}
func (w *responseWriter) Header() http.Header {
return w.h
}
func (w *responseWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
}
func (w *responseWriter) Response() *http.Response {
return &http.Response{
ProtoMajor: w.r.ProtoMajor,
ProtoMinor: w.r.ProtoMinor,
StatusCode: w.statusCode,
Header: w.h,
Body: ioutil.NopCloser(&w.Buffer),
}
}

Просмотреть файл

@ -0,0 +1,23 @@
package responsewriter
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"bytes"
"net/http"
"testing"
)
func TestResponseWriter(t *testing.T) {
w := New(&http.Request{ProtoMajor: 1, ProtoMinor: 1})
http.NotFound(w, nil)
buf := &bytes.Buffer{}
_ = w.Response().Write(buf)
if buf.String() != "HTTP/1.1 404 Not Found\r\nConnection: close\r\nContent-Type: text/plain; charset=utf-8\r\nX-Content-Type-Options: nosniff\r\n\r\n404 page not found\n" {
t.Error(buf.String())
}
}

8
pkg/proxy/generate.go Normal file
Просмотреть файл

@ -0,0 +1,8 @@
package proxy
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
//go:generate rm -rf ../util/mocks/$GOPACKAGE
//go:generate go run ../../vendor/github.com/golang/mock/mockgen -destination=../util/mocks/$GOPACKAGE/$GOPACKAGE.go github.com/Azure/ARO-RP/pkg/$GOPACKAGE Dialer
//go:generate go run ../../vendor/golang.org/x/tools/cmd/goimports -local=github.com/Azure/ARO-RP -e -w ../util/mocks/$GOPACKAGE/$GOPACKAGE.go

124
pkg/util/mocks/env/env.go поставляемый
Просмотреть файл

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/Azure/ARO-RP/pkg/env (interfaces: Interface)
// Source: github.com/Azure/ARO-RP/pkg/env (interfaces: Core,Interface)
// Package mock_env is a generated GoMock package.
package mock_env
@ -21,6 +21,128 @@ import (
refreshable "github.com/Azure/ARO-RP/pkg/util/refreshable"
)
// MockCore is a mock of Core interface
type MockCore struct {
ctrl *gomock.Controller
recorder *MockCoreMockRecorder
}
// MockCoreMockRecorder is the mock recorder for MockCore
type MockCoreMockRecorder struct {
mock *MockCore
}
// NewMockCore creates a new mock instance
func NewMockCore(ctrl *gomock.Controller) *MockCore {
mock := &MockCore{ctrl: ctrl}
mock.recorder = &MockCoreMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockCore) EXPECT() *MockCoreMockRecorder {
return m.recorder
}
// DeploymentMode mocks base method
func (m *MockCore) DeploymentMode() deployment.Mode {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeploymentMode")
ret0, _ := ret[0].(deployment.Mode)
return ret0
}
// DeploymentMode indicates an expected call of DeploymentMode
func (mr *MockCoreMockRecorder) DeploymentMode() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeploymentMode", reflect.TypeOf((*MockCore)(nil).DeploymentMode))
}
// Environment mocks base method
func (m *MockCore) Environment() *azure.Environment {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Environment")
ret0, _ := ret[0].(*azure.Environment)
return ret0
}
// Environment indicates an expected call of Environment
func (mr *MockCoreMockRecorder) Environment() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Environment", reflect.TypeOf((*MockCore)(nil).Environment))
}
// Location mocks base method
func (m *MockCore) Location() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Location")
ret0, _ := ret[0].(string)
return ret0
}
// Location indicates an expected call of Location
func (mr *MockCoreMockRecorder) Location() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Location", reflect.TypeOf((*MockCore)(nil).Location))
}
// NewRPAuthorizer mocks base method
func (m *MockCore) NewRPAuthorizer(arg0 string) (autorest.Authorizer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NewRPAuthorizer", arg0)
ret0, _ := ret[0].(autorest.Authorizer)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NewRPAuthorizer indicates an expected call of NewRPAuthorizer
func (mr *MockCoreMockRecorder) NewRPAuthorizer(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewRPAuthorizer", reflect.TypeOf((*MockCore)(nil).NewRPAuthorizer), arg0)
}
// ResourceGroup mocks base method
func (m *MockCore) ResourceGroup() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ResourceGroup")
ret0, _ := ret[0].(string)
return ret0
}
// ResourceGroup indicates an expected call of ResourceGroup
func (mr *MockCoreMockRecorder) ResourceGroup() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResourceGroup", reflect.TypeOf((*MockCore)(nil).ResourceGroup))
}
// SubscriptionID mocks base method
func (m *MockCore) SubscriptionID() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SubscriptionID")
ret0, _ := ret[0].(string)
return ret0
}
// SubscriptionID indicates an expected call of SubscriptionID
func (mr *MockCoreMockRecorder) SubscriptionID() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscriptionID", reflect.TypeOf((*MockCore)(nil).SubscriptionID))
}
// TenantID mocks base method
func (m *MockCore) TenantID() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TenantID")
ret0, _ := ret[0].(string)
return ret0
}
// TenantID indicates an expected call of TenantID
func (mr *MockCoreMockRecorder) TenantID() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TenantID", reflect.TypeOf((*MockCore)(nil).TenantID))
}
// MockInterface is a mock of Interface interface
type MockInterface struct {
ctrl *gomock.Controller

Просмотреть файл

@ -0,0 +1,51 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/Azure/ARO-RP/pkg/proxy (interfaces: Dialer)
// Package mock_proxy is a generated GoMock package.
package mock_proxy
import (
context "context"
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockDialer is a mock of Dialer interface
type MockDialer struct {
ctrl *gomock.Controller
recorder *MockDialerMockRecorder
}
// MockDialerMockRecorder is the mock recorder for MockDialer
type MockDialerMockRecorder struct {
mock *MockDialer
}
// NewMockDialer creates a new mock instance
func NewMockDialer(ctrl *gomock.Controller) *MockDialer {
mock := &MockDialer{ctrl: ctrl}
mock.recorder = &MockDialerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockDialer) EXPECT() *MockDialerMockRecorder {
return m.recorder
}
// DialContext mocks base method
func (m *MockDialer) DialContext(arg0 context.Context, arg1, arg2 string) (net.Conn, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DialContext", arg0, arg1, arg2)
ret0, _ := ret[0].(net.Conn)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DialContext indicates an expected call of DialContext
func (mr *MockDialerMockRecorder) DialContext(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialContext", reflect.TypeOf((*MockDialer)(nil).DialContext), arg0, arg1, arg2)
}

Просмотреть файл

@ -40,7 +40,13 @@ func RestConfig(dialer proxy.Dialer, oc *api.OpenShiftCluster) (*rest.Config, er
return nil, err
}
restconfig.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
restconfig.Dial = DialContext(dialer, oc)
return restconfig, nil
}
func DialContext(dialer proxy.Dialer, oc *api.OpenShiftCluster) func(ctx context.Context, network, address string) (net.Conn, error) {
return func(ctx context.Context, network, address string) (net.Conn, error) {
if network != "tcp" {
return nil, fmt.Errorf("unimplemented network %q", network)
}
@ -52,6 +58,4 @@ func RestConfig(dialer proxy.Dialer, oc *api.OpenShiftCluster) (*rest.Config, er
return dialer.DialContext(ctx, network, oc.Properties.NetworkProfile.PrivateEndpointIP+":"+port)
}
return restconfig, nil
}

Просмотреть файл

@ -0,0 +1,14 @@
package roundtripper
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"net/http"
)
type RoundTripperFunc func(*http.Request) (*http.Response, error)
func (f RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}

Просмотреть файл

@ -21,6 +21,7 @@ type Checker struct {
subscriptionDocuments []*api.SubscriptionDocument
billingDocuments []*api.BillingDocument
asyncOperationDocuments []*api.AsyncOperationDocument
portalDocuments []*api.PortalDocument
}
func NewChecker() *Checker {
@ -43,6 +44,10 @@ func (f *Checker) AddAsyncOperationDocuments(docs ...*api.AsyncOperationDocument
f.asyncOperationDocuments = append(f.asyncOperationDocuments, docs...)
}
func (f *Checker) AddPortalDocuments(docs ...*api.PortalDocument) {
f.portalDocuments = append(f.portalDocuments, docs...)
}
func (f *Checker) CheckOpenShiftClusters(openShiftClusters *cosmosdb.FakeOpenShiftClusterDocumentClient) (errs []error) {
ctx := context.Background()
@ -134,3 +139,23 @@ func (f *Checker) CheckAsyncOperations(asyncOperations *cosmosdb.FakeAsyncOperat
return errs
}
func (f *Checker) CheckPortals(portals *cosmosdb.FakePortalDocumentClient) (errs []error) {
ctx := context.Background()
all, err := portals.ListAll(ctx, nil)
if err != nil {
return []error{err}
}
if len(f.portalDocuments) != 0 && len(all.PortalDocuments) == len(f.portalDocuments) {
diff := deep.Equal(all.PortalDocuments, f.portalDocuments)
for _, i := range diff {
errs = append(errs, errors.New(i))
}
} else if len(all.PortalDocuments) != 0 || len(f.portalDocuments) != 0 {
errs = append(errs, fmt.Errorf("portals length different, %d vs %d", len(all.PortalDocuments), len(f.portalDocuments)))
}
return errs
}

Просмотреть файл

@ -17,11 +17,13 @@ type Fixture struct {
subscriptionDocuments []*api.SubscriptionDocument
billingDocuments []*api.BillingDocument
asyncOperationDocuments []*api.AsyncOperationDocument
portalDocuments []*api.PortalDocument
openShiftClustersDatabase database.OpenShiftClusters
billingDatabase database.Billing
subscriptionsDatabase database.Subscriptions
asyncOperationsDatabase database.AsyncOperations
portalDatabase database.Portal
}
func NewFixture() *Fixture {
@ -48,6 +50,11 @@ func (f *Fixture) WithAsyncOperations(db database.AsyncOperations) *Fixture {
return f
}
func (f *Fixture) WithPortal(db database.Portal) *Fixture {
f.portalDatabase = db
return f
}
func (f *Fixture) AddOpenShiftClusterDocuments(docs ...*api.OpenShiftClusterDocument) {
f.openshiftClusterDocuments = append(f.openshiftClusterDocuments, docs...)
}
@ -64,6 +71,10 @@ func (f *Fixture) AddAsyncOperationDocuments(docs ...*api.AsyncOperationDocument
f.asyncOperationDocuments = append(f.asyncOperationDocuments, docs...)
}
func (f *Fixture) AddPortalDocuments(docs ...*api.PortalDocument) {
f.portalDocuments = append(f.portalDocuments, docs...)
}
func (f *Fixture) Create() error {
ctx := context.Background()
@ -98,5 +109,12 @@ func (f *Fixture) Create() error {
}
}
for _, i := range f.portalDocuments {
_, err := f.portalDatabase.Create(ctx, i)
if err != nil {
return err
}
}
return nil
}

Просмотреть файл

@ -47,3 +47,9 @@ func NewFakeAsyncOperations() (db database.AsyncOperations, client *cosmosdb.Fak
db = database.NewAsyncOperationsWithProvidedClient(client)
return db, client
}
func NewFakePortal() (db database.Portal, client *cosmosdb.FakePortalDocumentClient) {
client = cosmosdb.NewFakePortalDocumentClient(jsonHandle)
db = database.NewPortalWithProvidedClient(client)
return db, client
}

Просмотреть файл

@ -0,0 +1,94 @@
package bufferedpipe
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"bytes"
"errors"
"io"
"net"
"sync"
"time"
)
// New returns two net.Conns representing either side of a buffered pipe. It's
// like net.Pipe() but with buffering. Note that deadlines are currently not
// implemented.
func New() (net.Conn, net.Conn) {
p := &p{
cond: sync.NewCond(&sync.Mutex{}),
}
return &conn{p, 0}, &conn{p, 1}
}
type p struct {
cond *sync.Cond
buf [2]bytes.Buffer
closed [2]bool
}
type conn struct {
p *p
n int
}
func (c *conn) Read(b []byte) (int, error) {
c.p.cond.L.Lock()
defer c.p.cond.L.Unlock()
for {
if c.p.closed[c.n] {
// Read() concurrently with, or after Close()
return 0, errors.New("connection closed")
}
if c.p.buf[c.n^1].Len() > 0 {
return c.p.buf[c.n^1].Read(b)
}
if c.p.closed[c.n^1] {
// Other side closed and read buffer is drained
return 0, io.EOF
}
c.p.cond.Wait()
}
}
func (c *conn) Write(b []byte) (int, error) {
c.p.cond.L.Lock()
defer c.p.cond.L.Unlock()
if c.p.closed[c.n] {
return 0, errors.New("connection closed")
}
c.p.cond.Broadcast()
return c.p.buf[c.n].Write(b)
}
func (c *conn) Close() error {
c.p.cond.L.Lock()
defer c.p.cond.L.Unlock()
if c.p.closed[c.n] {
return errors.New("connection closed")
}
c.p.closed[c.n] = true
c.p.cond.Broadcast()
return nil
}
func (c *conn) LocalAddr() net.Addr { return &addr{} }
func (c *conn) RemoteAddr() net.Addr { return &addr{} }
func (c *conn) SetDeadline(time.Time) error { return errors.New("not implemented") }
func (c *conn) SetReadDeadline(time.Time) error { return errors.New("not implemented") }
func (c *conn) SetWriteDeadline(time.Time) error { return errors.New("not implemented") }
type addr struct{}
func (addr) Network() string { return "bufferedpipe" }
func (addr) String() string { return "bufferedpipe" }

Просмотреть файл

@ -7,6 +7,8 @@ import (
"context"
"fmt"
"net"
"github.com/Azure/ARO-RP/test/util/bufferedpipe"
)
type addr struct{}
@ -46,7 +48,11 @@ func (*Listener) Addr() net.Addr {
}
func (l *Listener) DialContext(context.Context, string, string) (net.Conn, error) {
c1, c2 := net.Pipe()
c1, c2 := bufferedpipe.New()
l.c <- c1
return c2, nil
}
func (l *Listener) Enqueue(c net.Conn) {
l.c <- c
}