diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index dd9c7bef3..836a6490f 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -6,6 +6,7 @@ package proxy import ( "crypto/tls" "crypto/x509" + "errors" "io" "io/ioutil" "net" @@ -24,6 +25,7 @@ type Server struct { KeyFile string ClientCertFile string Subnet string + subnet *net.IPNet } func (s *Server) Run() error { @@ -31,6 +33,7 @@ func (s *Server) Run() error { if err != nil { return err } + s.subnet = subnet b, err := ioutil.ReadFile(s.ClientCertFile) if err != nil { @@ -91,25 +94,37 @@ func (s *Server) Run() error { return err } - return http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - return - } + return http.Serve(l, http.HandlerFunc(s.proxyHandler)) +} - ip, _, err := net.SplitHostPort(r.Host) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } +func (s Server) proxyHandler(w http.ResponseWriter, r *http.Request) { + err := s.validateProxyResquest(w, r) + if err != nil { + return + } + Proxy(s.Log, w, r, 0) +} - if !subnet.Contains(net.ParseIP(ip)) { - http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) - return - } +// validateProxyResquest checks that the request is valid. If not, it writes the +// appropriate http headers and returns an error. +func (s Server) validateProxyResquest(w http.ResponseWriter, r *http.Request) error { + if r.Method != http.MethodConnect { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return errors.New("Request is not valid") + } - Proxy(s.Log, w, r, 0) - })) + ip, _, err := net.SplitHostPort(r.Host) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return errors.New("Request is not valid") + } + + if !s.subnet.Contains(net.ParseIP(ip)) { + http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) + return errors.New("Request is not valid") + } + + return nil } // Proxy takes an HTTP/1.x CONNECT Request and ResponseWriter from the Golang diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go new file mode 100644 index 000000000..552c7222e --- /dev/null +++ b/pkg/proxy/proxy_test.go @@ -0,0 +1,117 @@ +package proxy + +import ( + "net" + "net/http" + "net/http/httptest" + "testing" +) + +func TestProxyRequestValidationMethod(t *testing.T) { + server := Server{Subnet: "127.0.0.1/24"} + _, subnet, err := net.ParseCIDR(server.Subnet) + if err != nil { + t.FailNow() + } + server.subnet = subnet + + //This should fail because the method is not CONNECT + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "https://127.0.0.1:123", nil) + + server.validateProxyResquest(recorder, request) + + response := recorder.Result() + if response.StatusCode != http.StatusMethodNotAllowed { + t.Logf("Test failed. Reason: was expecting status code to be %d but it was %d", http.StatusMethodNotAllowed, response.StatusCode) + t.FailNow() + } + + //This should succeed because the method is CONNECT + recorder = httptest.NewRecorder() + request = httptest.NewRequest(http.MethodConnect, "127.0.0.1:123", nil) + + server.validateProxyResquest(recorder, request) + + response = recorder.Result() + + if response.StatusCode != http.StatusOK { + t.Logf("Test failed. Reason: was expecting status code to be %d but it was %d", http.StatusOK, response.StatusCode) + t.FailNow() + } + +} + +func TestProxyRequestValidationHostname(t *testing.T) { + + server := Server{Subnet: "127.0.0.1/24"} + _, subnet, err := net.ParseCIDR(server.Subnet) + if err != nil { + t.FailNow() + } + server.subnet = subnet + + //This should fail because the hostname in not valid + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodConnect, "", nil) + + server.validateProxyResquest(recorder, request) + + response := recorder.Result() + + if response.StatusCode != http.StatusBadRequest { + t.Logf("Test failed. Reason: was expecting status code to be %d but it was %d", http.StatusBadRequest, response.StatusCode) + t.FailNow() + } + + //This should succeed because the hostname is valid + recorder = httptest.NewRecorder() + request = httptest.NewRequest(http.MethodConnect, "127.0.0.1:8443", nil) + + server.validateProxyResquest(recorder, request) + + response = recorder.Result() + + if response.StatusCode != http.StatusOK { + t.Logf("Test failed. Reason: was expecting status code to be %d but it was %d", http.StatusOK, response.StatusCode) + t.FailNow() + } + +} + +func TestProxyRequestValidationSubnet(t *testing.T) { + + server := Server{Subnet: "127.0.0.1/24"} + _, subnet, err := net.ParseCIDR(server.Subnet) + if err != nil { + t.FailNow() + } + server.subnet = subnet + + //This should succeed because it is in the subnet + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodConnect, "127.0.0.1:1234", nil) + + server.validateProxyResquest(recorder, request) + + response := recorder.Result() + + if response.StatusCode != http.StatusOK { + t.Logf("Test failed. Reason: was expecting status code to be %d but it was %d", http.StatusOK, response.StatusCode) + t.FailNow() + } + + //This should fail because it is not in the subnet + recorder = httptest.NewRecorder() + request = httptest.NewRequest(http.MethodConnect, "10.0.0.1:1234", nil) + + server.validateProxyResquest(recorder, request) + + response = recorder.Result() + + if response.StatusCode != http.StatusForbidden { + t.Logf("Test failed. Reason: was expecting status code to be %d but it was %d", http.StatusForbidden, response.StatusCode) + t.FailNow() + } + +}