Merge pull request #2007 from facchettos/proxy-refactor-and-test

Add testing to the proxy
This commit is contained in:
Ben Vesel 2022-04-13 10:37:04 -04:00 коммит произвёл GitHub
Родитель 31ee47d9d8 bfb4d4b3ae
Коммит eab506db38
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 114 добавлений и 16 удалений

Просмотреть файл

@ -6,6 +6,7 @@ package proxy
import (
"crypto/tls"
"crypto/x509"
"errors"
"io"
"io/ioutil"
"net"
@ -25,6 +26,7 @@ type Server struct {
KeyFile string
ClientCertFile string
Subnet string
subnet *net.IPNet
}
func (s *Server) Run() error {
@ -32,6 +34,7 @@ func (s *Server) Run() error {
if err != nil {
return err
}
s.subnet = subnet
b, err := ioutil.ReadFile(s.ClientCertFile)
if err != nil {
@ -92,25 +95,38 @@ func (s *Server) Run() error {
return err
}
return http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
return http.Serve(l, http.HandlerFunc(s.proxyHandler))
}
ip, _, err := net.SplitHostPort(r.Host)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
func (s Server) proxyHandler(w http.ResponseWriter, r *http.Request) {
err := s.validateProxyRequest(w, r)
if err != nil {
return
}
Proxy(s.Log, w, r, 0)
}
if !subnet.Contains(net.ParseIP(ip)) {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
// validateProxyRequest checks that the request is valid. If not, it writes the
// appropriate http headers and returns an error.
func (s Server) validateProxyRequest(w http.ResponseWriter, r *http.Request) error {
Proxy(s.Log, w, r, 0)
}))
ip, _, err := net.SplitHostPort(r.Host)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return err
}
if r.Method != http.MethodConnect {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return errors.New("Request is not valid, method is not CONNECT")
}
if !s.subnet.Contains(net.ParseIP(ip)) {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return errors.New("Request is not allowed, the originating IP is not part of the allowed subnet")
}
return nil
}
// Proxy takes an HTTP/1.x CONNECT Request and ResponseWriter from the Golang

82
pkg/proxy/proxy_test.go Normal file
Просмотреть файл

@ -0,0 +1,82 @@
package proxy
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestValidation(t *testing.T) {
tests := []struct {
name string
method string
subnet string
hostname string
wantStatus int
wantErr bool
}{
{
name: "get https same subnet",
method: http.MethodGet,
subnet: "127.0.0.1/24",
hostname: "https://127.0.0.2:123",
wantStatus: http.StatusMethodNotAllowed,
wantErr: true,
},
{
name: "connect http same subnet",
method: http.MethodConnect,
subnet: "127.0.0.1/24",
hostname: "127.0.0.2:123",
wantStatus: http.StatusOK,
wantErr: false,
},
{
name: "connect http different subnet",
method: http.MethodConnect,
subnet: "127.0.0.1/24",
hostname: "10.0.0.1:123",
wantStatus: http.StatusForbidden,
wantErr: true,
},
{
name: "wrong hostname",
method: http.MethodGet,
subnet: "127.0.0.1/24",
hostname: "https://127.0.0.1::",
wantStatus: http.StatusBadRequest,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := Server{Subnet: tt.subnet}
_, subnet, err := net.ParseCIDR(server.Subnet)
if err != nil {
t.FailNow()
}
server.subnet = subnet
recorder := httptest.NewRecorder()
request := httptest.NewRequest(tt.method, tt.hostname, nil)
err = server.validateProxyRequest(recorder, request)
if (err != nil && !tt.wantErr) || (err == nil && tt.wantErr) {
t.Error(err)
}
response := recorder.Result()
if response.StatusCode != tt.wantStatus {
fmt.Println(response.StatusCode, tt.wantStatus)
t.Error(tt.hostname)
}
})
}
}