200 строки
5.2 KiB
Go
200 строки
5.2 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"regexp"
|
|
"strings"
|
|
"syscall"
|
|
|
|
"github.com/Azure/go-autorest/autorest/azure"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// ProxyServer serves the client request and proxy to ARM endpoint
|
|
type ProxyServer struct {
|
|
port int
|
|
cloud string
|
|
logger *log.Logger
|
|
client *http.Client
|
|
throttle Throttle
|
|
}
|
|
|
|
// ListenAndServe starts the proxy server
|
|
func (s *ProxyServer) ListenAndServe() {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/", s.handleAzureRequests)
|
|
|
|
server := &http.Server{
|
|
Addr: fmt.Sprintf(":%d", s.port),
|
|
Handler: s.throttle.NewThrottle(mux),
|
|
}
|
|
|
|
go func() {
|
|
s.logger.Info("Start listening")
|
|
if err := server.ListenAndServe(); err != nil {
|
|
s.logger.Fatal("Failed to listen on proxy address: ", err)
|
|
}
|
|
}()
|
|
|
|
// listening to OS shutdown singal
|
|
signalChan := make(chan os.Signal, 1)
|
|
signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM)
|
|
<-signalChan
|
|
|
|
s.logger.Infof("Got shutdown signal, shutting down webhook server gracefully...")
|
|
|
|
server.Shutdown(context.Background()) // nolint: errcheck
|
|
}
|
|
|
|
func (s *ProxyServer) handleAzureRequests(w http.ResponseWriter, req *http.Request) {
|
|
logger := s.logger
|
|
logger.Infof("Proxy %s %s", req.Method, req.RequestURI)
|
|
|
|
defer req.Body.Close()
|
|
buf, err := ioutil.ReadAll(req.Body)
|
|
if err != nil {
|
|
logger.Errorf("Failed to read request: %v", err.Error())
|
|
s.handleProxyFailure(w)
|
|
return
|
|
}
|
|
|
|
apiVersion := req.URL.Query()["api-version"]
|
|
buf, err = enableTCPReset(req.Method, req.RequestURI, apiVersion[0], buf)
|
|
if err != nil {
|
|
logger.Errorf("Failed to enable TCP Reset: %v", err.Error())
|
|
s.handleProxyFailure(w)
|
|
return
|
|
}
|
|
|
|
env, err := azure.EnvironmentFromName(s.cloud)
|
|
if err != nil {
|
|
logger.Errorf("error in getting %s environment: %v", s.cloud, err)
|
|
return
|
|
}
|
|
|
|
reqURI := fmt.Sprintf("%s%s", env.ResourceManagerEndpoint, req.RequestURI)
|
|
proxyReq, err := http.NewRequest(req.Method, reqURI, bytes.NewReader(buf))
|
|
if err != nil {
|
|
logger.Errorf("Unable to construct request: %s", err.Error())
|
|
s.handleProxyFailure(w)
|
|
return
|
|
}
|
|
|
|
copyHeader(proxyReq.Header, req.Header)
|
|
|
|
resp, err := s.client.Do(proxyReq)
|
|
if err != nil {
|
|
logger.Errorf("Failed to proxy request: %v", err.Error())
|
|
s.handleProxyFailure(w)
|
|
return
|
|
}
|
|
|
|
copyHeader(w.Header(), resp.Header)
|
|
w.WriteHeader(resp.StatusCode)
|
|
|
|
defer resp.Body.Close()
|
|
|
|
if _, err := io.Copy(w, resp.Body); err != nil {
|
|
logger.Errorf("Failed to copy response body: %v", err.Error())
|
|
s.handleProxyFailure(w)
|
|
}
|
|
}
|
|
|
|
func copyHeader(dst, src http.Header) {
|
|
for k, vv := range src {
|
|
for _, v := range vv {
|
|
dst.Add(k, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *ProxyServer) handleProxyFailure(w http.ResponseWriter) {
|
|
w.WriteHeader(502)
|
|
_, err := w.Write([]byte("Internal Server Error"))
|
|
if err != nil {
|
|
s.logger.Errorf("Failed to handle proxy failure: %v", err.Error())
|
|
}
|
|
}
|
|
|
|
func enableTCPReset(httpMethod, requestURI, apiVersion string, input []byte) (output []byte, err error) {
|
|
if !strings.EqualFold(httpMethod, "PUT") ||
|
|
strings.Compare(apiVersion, "2018-07-01") < 0 {
|
|
return input, nil
|
|
}
|
|
|
|
var resourceType, _ = getResourceType(requestURI)
|
|
if !strings.EqualFold(resourceType, "Microsoft.Network/loadBalancers") {
|
|
return input, nil
|
|
}
|
|
|
|
var jsonBody map[string]interface{}
|
|
if err := json.Unmarshal(input, &jsonBody); err != nil {
|
|
return input, nil
|
|
}
|
|
|
|
sku := jsonBody["sku"]
|
|
if sku == nil || !strings.EqualFold(sku.(map[string]interface{})["name"].(string), "Standard") {
|
|
return input, nil
|
|
}
|
|
|
|
if jsonBody["properties"] == nil {
|
|
return input, nil
|
|
}
|
|
|
|
properties := jsonBody["properties"].(map[string]interface{})
|
|
|
|
if properties != nil && properties["loadBalancingRules"] != nil {
|
|
loadBalancingRules := properties["loadBalancingRules"].([]interface{})
|
|
for _, lbrule := range loadBalancingRules {
|
|
rule := lbrule.(map[string]interface{})
|
|
setEnableTCPReset(rule["properties"].(map[string]interface{}))
|
|
}
|
|
}
|
|
|
|
if properties != nil && properties["outboundRules"] != nil {
|
|
outboundRules := properties["outboundRules"].([]interface{})
|
|
for _, obrule := range outboundRules {
|
|
rule := obrule.(map[string]interface{})
|
|
setEnableTCPReset(rule["properties"].(map[string]interface{}))
|
|
}
|
|
}
|
|
|
|
if properties != nil && properties["inboundNatRules"] != nil {
|
|
inboundNatRules := properties["inboundNatRules"].([]interface{})
|
|
for _, natrule := range inboundNatRules {
|
|
rule := natrule.(map[string]interface{})
|
|
setEnableTCPReset(rule["properties"].(map[string]interface{}))
|
|
}
|
|
}
|
|
|
|
return json.Marshal(jsonBody)
|
|
}
|
|
|
|
func setEnableTCPReset(properties map[string]interface{}) {
|
|
if properties != nil {
|
|
properties["enableTcpReset"] = true
|
|
}
|
|
}
|
|
|
|
// based on the ParseResourceID method defined here
|
|
// https://github.com/Azure/go-autorest/blob/master/autorest/azure/azure.go#L176
|
|
func getResourceType(resourceID string) (string, error) {
|
|
const resourceIDPatternText = `(?i)subscriptions/(.+)/resourceGroups/(.+)/providers/(.+?)/(.+?)/(.+)`
|
|
resourceIDPattern := regexp.MustCompile(resourceIDPatternText)
|
|
match := resourceIDPattern.FindStringSubmatch(resourceID)
|
|
|
|
if len(match) == 0 || len(match) < 5 {
|
|
return "", fmt.Errorf("parsing failed for %s. Invalid resource Id format", resourceID)
|
|
}
|
|
|
|
return fmt.Sprintf("%s/%s", match[3], match[4]), nil
|
|
}
|