aks-traffic-manager/proxy.go

200 строки
5.2 KiB
Go

package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"os/signal"
"regexp"
"strings"
"syscall"
"github.com/Azure/go-autorest/autorest/azure"
log "github.com/sirupsen/logrus"
)
// ProxyServer serves the client request and proxy to ARM endpoint
type ProxyServer struct {
port int
cloud string
logger *log.Logger
client *http.Client
throttle Throttle
}
// ListenAndServe starts the proxy server
func (s *ProxyServer) ListenAndServe() {
mux := http.NewServeMux()
mux.HandleFunc("/", s.handleAzureRequests)
server := &http.Server{
Addr: fmt.Sprintf(":%d", s.port),
Handler: s.throttle.NewThrottle(mux),
}
go func() {
s.logger.Info("Start listening")
if err := server.ListenAndServe(); err != nil {
s.logger.Fatal("Failed to listen on proxy address: ", err)
}
}()
// listening to OS shutdown singal
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM)
<-signalChan
s.logger.Infof("Got shutdown signal, shutting down webhook server gracefully...")
server.Shutdown(context.Background()) // nolint: errcheck
}
func (s *ProxyServer) handleAzureRequests(w http.ResponseWriter, req *http.Request) {
logger := s.logger
logger.Infof("Proxy %s %s", req.Method, req.RequestURI)
defer req.Body.Close()
buf, err := ioutil.ReadAll(req.Body)
if err != nil {
logger.Errorf("Failed to read request: %v", err.Error())
s.handleProxyFailure(w)
return
}
apiVersion := req.URL.Query()["api-version"]
buf, err = enableTCPReset(req.Method, req.RequestURI, apiVersion[0], buf)
if err != nil {
logger.Errorf("Failed to enable TCP Reset: %v", err.Error())
s.handleProxyFailure(w)
return
}
env, err := azure.EnvironmentFromName(s.cloud)
if err != nil {
logger.Errorf("error in getting %s environment: %v", s.cloud, err)
return
}
reqURI := fmt.Sprintf("%s%s", env.ResourceManagerEndpoint, req.RequestURI)
proxyReq, err := http.NewRequest(req.Method, reqURI, bytes.NewReader(buf))
if err != nil {
logger.Errorf("Unable to construct request: %s", err.Error())
s.handleProxyFailure(w)
return
}
copyHeader(proxyReq.Header, req.Header)
resp, err := s.client.Do(proxyReq)
if err != nil {
logger.Errorf("Failed to proxy request: %v", err.Error())
s.handleProxyFailure(w)
return
}
copyHeader(w.Header(), resp.Header)
w.WriteHeader(resp.StatusCode)
defer resp.Body.Close()
if _, err := io.Copy(w, resp.Body); err != nil {
logger.Errorf("Failed to copy response body: %v", err.Error())
s.handleProxyFailure(w)
}
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func (s *ProxyServer) handleProxyFailure(w http.ResponseWriter) {
w.WriteHeader(502)
_, err := w.Write([]byte("Internal Server Error"))
if err != nil {
s.logger.Errorf("Failed to handle proxy failure: %v", err.Error())
}
}
func enableTCPReset(httpMethod, requestURI, apiVersion string, input []byte) (output []byte, err error) {
if !strings.EqualFold(httpMethod, "PUT") ||
strings.Compare(apiVersion, "2018-07-01") < 0 {
return input, nil
}
var resourceType, _ = getResourceType(requestURI)
if !strings.EqualFold(resourceType, "Microsoft.Network/loadBalancers") {
return input, nil
}
var jsonBody map[string]interface{}
if err := json.Unmarshal(input, &jsonBody); err != nil {
return input, nil
}
sku := jsonBody["sku"]
if sku == nil || !strings.EqualFold(sku.(map[string]interface{})["name"].(string), "Standard") {
return input, nil
}
if jsonBody["properties"] == nil {
return input, nil
}
properties := jsonBody["properties"].(map[string]interface{})
if properties != nil && properties["loadBalancingRules"] != nil {
loadBalancingRules := properties["loadBalancingRules"].([]interface{})
for _, lbrule := range loadBalancingRules {
rule := lbrule.(map[string]interface{})
setEnableTCPReset(rule["properties"].(map[string]interface{}))
}
}
if properties != nil && properties["outboundRules"] != nil {
outboundRules := properties["outboundRules"].([]interface{})
for _, obrule := range outboundRules {
rule := obrule.(map[string]interface{})
setEnableTCPReset(rule["properties"].(map[string]interface{}))
}
}
if properties != nil && properties["inboundNatRules"] != nil {
inboundNatRules := properties["inboundNatRules"].([]interface{})
for _, natrule := range inboundNatRules {
rule := natrule.(map[string]interface{})
setEnableTCPReset(rule["properties"].(map[string]interface{}))
}
}
return json.Marshal(jsonBody)
}
func setEnableTCPReset(properties map[string]interface{}) {
if properties != nil {
properties["enableTcpReset"] = true
}
}
// based on the ParseResourceID method defined here
// https://github.com/Azure/go-autorest/blob/master/autorest/azure/azure.go#L176
func getResourceType(resourceID string) (string, error) {
const resourceIDPatternText = `(?i)subscriptions/(.+)/resourceGroups/(.+)/providers/(.+?)/(.+?)/(.+)`
resourceIDPattern := regexp.MustCompile(resourceIDPatternText)
match := resourceIDPattern.FindStringSubmatch(resourceID)
if len(match) == 0 || len(match) < 5 {
return "", fmt.Errorf("parsing failed for %s. Invalid resource Id format", resourceID)
}
return fmt.Sprintf("%s/%s", match[3], match[4]), nil
}