зеркало из https://github.com/Azure/ARO-RP.git
Merge pull request #2007 from facchettos/proxy-refactor-and-test
Add testing to the proxy
This commit is contained in:
Коммит
eab506db38
|
@ -6,6 +6,7 @@ package proxy
|
|||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
|
@ -25,6 +26,7 @@ type Server struct {
|
|||
KeyFile string
|
||||
ClientCertFile string
|
||||
Subnet string
|
||||
subnet *net.IPNet
|
||||
}
|
||||
|
||||
func (s *Server) Run() error {
|
||||
|
@ -32,6 +34,7 @@ func (s *Server) Run() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.subnet = subnet
|
||||
|
||||
b, err := ioutil.ReadFile(s.ClientCertFile)
|
||||
if err != nil {
|
||||
|
@ -92,25 +95,38 @@ func (s *Server) Run() error {
|
|||
return err
|
||||
}
|
||||
|
||||
return http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
return http.Serve(l, http.HandlerFunc(s.proxyHandler))
|
||||
}
|
||||
|
||||
ip, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
func (s Server) proxyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
err := s.validateProxyRequest(w, r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
Proxy(s.Log, w, r, 0)
|
||||
}
|
||||
|
||||
if !subnet.Contains(net.ParseIP(ip)) {
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
// validateProxyRequest checks that the request is valid. If not, it writes the
|
||||
// appropriate http headers and returns an error.
|
||||
func (s Server) validateProxyRequest(w http.ResponseWriter, r *http.Request) error {
|
||||
|
||||
Proxy(s.Log, w, r, 0)
|
||||
}))
|
||||
ip, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return err
|
||||
}
|
||||
|
||||
if r.Method != http.MethodConnect {
|
||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||
return errors.New("Request is not valid, method is not CONNECT")
|
||||
}
|
||||
|
||||
if !s.subnet.Contains(net.ParseIP(ip)) {
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return errors.New("Request is not allowed, the originating IP is not part of the allowed subnet")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Proxy takes an HTTP/1.x CONNECT Request and ResponseWriter from the Golang
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
package proxy
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequestValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
subnet string
|
||||
hostname string
|
||||
wantStatus int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "get https same subnet",
|
||||
method: http.MethodGet,
|
||||
subnet: "127.0.0.1/24",
|
||||
hostname: "https://127.0.0.2:123",
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "connect http same subnet",
|
||||
method: http.MethodConnect,
|
||||
subnet: "127.0.0.1/24",
|
||||
hostname: "127.0.0.2:123",
|
||||
wantStatus: http.StatusOK,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "connect http different subnet",
|
||||
method: http.MethodConnect,
|
||||
subnet: "127.0.0.1/24",
|
||||
hostname: "10.0.0.1:123",
|
||||
wantStatus: http.StatusForbidden,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong hostname",
|
||||
method: http.MethodGet,
|
||||
subnet: "127.0.0.1/24",
|
||||
hostname: "https://127.0.0.1::",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := Server{Subnet: tt.subnet}
|
||||
_, subnet, err := net.ParseCIDR(server.Subnet)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
server.subnet = subnet
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(tt.method, tt.hostname, nil)
|
||||
|
||||
err = server.validateProxyRequest(recorder, request)
|
||||
if (err != nil && !tt.wantErr) || (err == nil && tt.wantErr) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
response := recorder.Result()
|
||||
|
||||
if response.StatusCode != tt.wantStatus {
|
||||
fmt.Println(response.StatusCode, tt.wantStatus)
|
||||
t.Error(tt.hostname)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Загрузка…
Ссылка в новой задаче