зеркало из https://github.com/Azure/go-ntlmssp.git
152 строки
4.0 KiB
Go
152 строки
4.0 KiB
Go
package ntlmssp
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"strings"
|
|
)
|
|
|
|
// GetDomain splits a user name into its user and domain parts.
//
// For a name of the form DOMAIN\user (split at the first backslash) it
// returns ("user", "DOMAIN", true). Otherwise the name is returned unchanged
// with an empty domain. The third result, domainNeeded, is then false when
// the name contains "@" (a UPN such as user@example.com) and true for a bare
// user name.
func GetDomain(user string) (string, string, bool) {
	if sep := strings.Index(user, "\\"); sep >= 0 {
		// Down-level logon name: everything before the first backslash is
		// the domain, everything after it is the user.
		return user[sep+1:], user[:sep], true
	}

	// No explicit domain. UPN-style names report that no separate domain is
	// needed; plain names report that one is.
	isUPN := strings.Contains(user, "@")
	return user, "", !isUPN
}
|
|
|
|
// Negotiator is an http.RoundTripper decorator that automatically converts
// Basic authentication to NTLM/Negotiate authentication when appropriate.
//
// The embedded RoundTripper performs the actual HTTP exchanges. When it is
// nil, http.DefaultTransport is used.
type Negotiator struct {
	http.RoundTripper
}
|
|
|
|
//RoundTrip sends the request to the server, handling any authentication
|
|
//re-sends as needed.
|
|
func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error) {
|
|
// Use default round tripper if not provided
|
|
rt := l.RoundTripper
|
|
if rt == nil {
|
|
rt = http.DefaultTransport
|
|
}
|
|
// If it is not basic auth, just round trip the request as usual
|
|
reqauth := authheader(req.Header.Values("Authorization"))
|
|
if !reqauth.IsBasic() {
|
|
return rt.RoundTrip(req)
|
|
}
|
|
reqauthBasic := reqauth.Basic()
|
|
// Save request body
|
|
body := bytes.Buffer{}
|
|
if req.Body != nil {
|
|
_, err = body.ReadFrom(req.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req.Body.Close()
|
|
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
|
}
|
|
// first try anonymous, in case the server still finds us
|
|
// authenticated from previous traffic
|
|
req.Header.Del("Authorization")
|
|
res, err = rt.RoundTrip(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if res.StatusCode != http.StatusUnauthorized {
|
|
return res, err
|
|
}
|
|
resauth := authheader(res.Header.Values("Www-Authenticate"))
|
|
if !resauth.IsNegotiate() && !resauth.IsNTLM() {
|
|
// Unauthorized, Negotiate not requested, let's try with basic auth
|
|
req.Header.Set("Authorization", string(reqauthBasic))
|
|
io.Copy(ioutil.Discard, res.Body)
|
|
res.Body.Close()
|
|
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
|
|
|
res, err = rt.RoundTrip(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if res.StatusCode != http.StatusUnauthorized {
|
|
return res, err
|
|
}
|
|
resauth = authheader(res.Header.Values("Www-Authenticate"))
|
|
}
|
|
|
|
if resauth.IsNegotiate() || resauth.IsNTLM() {
|
|
// 401 with request:Basic and response:Negotiate
|
|
io.Copy(ioutil.Discard, res.Body)
|
|
res.Body.Close()
|
|
|
|
// recycle credentials
|
|
u, p, err := reqauth.GetBasicCreds()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// get domain from username
|
|
domain := ""
|
|
u, domain, domainNeeded := GetDomain(u)
|
|
|
|
// send negotiate
|
|
negotiateMessage, err := NewNegotiateMessage(domain, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resauth.IsNTLM() {
|
|
req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(negotiateMessage))
|
|
} else {
|
|
req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(negotiateMessage))
|
|
}
|
|
|
|
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
|
|
|
res, err = rt.RoundTrip(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// receive challenge?
|
|
resauth = authheader(res.Header.Values("Www-Authenticate"))
|
|
challengeMessage, err := resauth.GetData()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !(resauth.IsNegotiate() || resauth.IsNTLM()) || len(challengeMessage) == 0 {
|
|
// Negotiation failed, let client deal with response
|
|
return res, nil
|
|
}
|
|
io.Copy(ioutil.Discard, res.Body)
|
|
res.Body.Close()
|
|
|
|
// send authenticate
|
|
authenticateMessage, err := ProcessChallenge(challengeMessage, u, p, domainNeeded)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resauth.IsNTLM() {
|
|
req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(authenticateMessage))
|
|
} else {
|
|
req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(authenticateMessage))
|
|
}
|
|
|
|
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
|
|
|
return rt.RoundTrip(req)
|
|
}
|
|
|
|
return res, err
|
|
}
|