Make handshake errors from Accept non-fatal
This avoids the servers stopping when one client fails to complete the handshake. Addresses #255
This commit is contained in:
Родитель
3636c18fc0
Коммит
8ad6de55a3
|
@ -142,6 +142,7 @@ Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contribu
|
|||
* [Jeroen de Bruijn](https://github.com/vidavidorra)
|
||||
* [bjdgyc](https://github.com/bjdgyc)
|
||||
* [Jeffrey Stoke (Jeff Ctor)](https://github.com/jeffreystoke) - *Fragmentbuffer Fix*
|
||||
* [Frank Olbricht](https://github.com/folbricht)
|
||||
|
||||
### License
|
||||
MIT License - see [LICENSE](LICENSE) for full text
|
||||
|
|
|
@ -170,7 +170,7 @@ func parseCipherSuites(userSelectedSuites []CipherSuiteID, excludePSK, excludeNo
|
|||
for _, id := range ids {
|
||||
c := cipherSuiteForID(id)
|
||||
if c == nil {
|
||||
return nil, fmt.Errorf("CipherSuite with id(%d) is not valid", id)
|
||||
return nil, &invalidCipherSuite{id}
|
||||
}
|
||||
cipherSuites = append(cipherSuites, c)
|
||||
}
|
||||
|
|
16
conn.go
16
conn.go
|
@ -2,6 +2,7 @@ package dtls
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -879,16 +880,13 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
|
|||
}
|
||||
|
||||
func (c *Conn) translateHandshakeCtxError(err error) error {
|
||||
switch err {
|
||||
case context.Canceled:
|
||||
if c.isHandshakeCompletedSuccessfully() {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
case context.DeadlineExceeded:
|
||||
return errHandshakeTimeout
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
|
||||
return nil
|
||||
}
|
||||
return &HandshakeError{err}
|
||||
}
|
||||
|
||||
func (c *Conn) close(byUser bool) error {
|
||||
|
|
|
@ -132,8 +132,8 @@ func TestContextConfig(t *testing.T) {
|
|||
d, cancel := dial.f()
|
||||
conn, err := d()
|
||||
defer cancel()
|
||||
if err != errHandshakeTimeout {
|
||||
t.Errorf("Expected error: '%v', got: '%v'", errHandshakeTimeout, err)
|
||||
if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
|
||||
t.Errorf("Client error exp(Temporary network error) failed(%v)", err)
|
||||
close(done)
|
||||
return
|
||||
}
|
||||
|
|
87
conn_test.go
87
conn_test.go
|
@ -290,7 +290,7 @@ func TestHandshakeWithAlert(t *testing.T) {
|
|||
|
||||
cases := map[string]struct {
|
||||
configServer, configClient *Config
|
||||
errServer, errClient interface{}
|
||||
errServer, errClient error
|
||||
}{
|
||||
"CipherSuiteNoIntersection": {
|
||||
configServer: &Config{
|
||||
|
@ -300,7 +300,7 @@ func TestHandshakeWithAlert(t *testing.T) {
|
|||
CipherSuites: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
|
||||
},
|
||||
errServer: errCipherSuiteNoIntersection,
|
||||
errClient: "alert: Alert LevelFatal: InsufficientSecurity",
|
||||
errClient: &errAlert{&alert{alertLevelFatal, alertInsufficientSecurity}},
|
||||
},
|
||||
"SignatureSchemesNoIntersection": {
|
||||
configServer: &Config{
|
||||
|
@ -311,7 +311,7 @@ func TestHandshakeWithAlert(t *testing.T) {
|
|||
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
||||
SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP521AndSHA512},
|
||||
},
|
||||
errServer: "alert: Alert LevelFatal: InsufficientSecurity",
|
||||
errServer: &errAlert{&alert{alertLevelFatal, alertInsufficientSecurity}},
|
||||
errClient: errNoAvailableSignatureSchemes,
|
||||
},
|
||||
}
|
||||
|
@ -328,17 +328,13 @@ func TestHandshakeWithAlert(t *testing.T) {
|
|||
}()
|
||||
|
||||
_, errServer := testServer(ctx, cb, testCase.configServer, true)
|
||||
if errExp, ok := testCase.errServer.(error); ok && errServer != errExp {
|
||||
t.Fatalf("Server error exp(%v) failed(%v)", errExp, errServer)
|
||||
} else if strExp, ok := testCase.errServer.(string); ok && errServer.Error() != strExp {
|
||||
t.Fatalf("Server error exp(%s) failed(%v)", strExp, errServer)
|
||||
if !errors.Is(errServer, testCase.errServer) {
|
||||
t.Fatalf("Server error exp(%v) failed(%v)", testCase.errServer, errServer)
|
||||
}
|
||||
|
||||
errClient := <-clientErr
|
||||
if errExp, ok := testCase.errClient.(error); ok && errClient != errExp {
|
||||
t.Fatalf("Client error exp(%v) failed(%v)", errExp, errClient)
|
||||
} else if strExp, ok := testCase.errClient.(string); ok && errClient.Error() != strExp {
|
||||
t.Fatalf("Client error exp(%s) failed(%v)", strExp, errClient)
|
||||
if !errors.Is(errClient, testCase.errClient) {
|
||||
t.Fatalf("Client error exp(%v) failed(%v)", testCase.errClient, errClient)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -490,7 +486,7 @@ func TestPSKHintFail(t *testing.T) {
|
|||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
serverAlertError := errors.New("alert: Alert LevelFatal: InternalError")
|
||||
serverAlertError := &errAlert{&alert{alertLevelFatal, alertInternalError}}
|
||||
pskRejected := errors.New("PSK Rejected")
|
||||
|
||||
// Limit runtime in case of deadlocks
|
||||
|
@ -523,11 +519,11 @@ func TestPSKHintFail(t *testing.T) {
|
|||
CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
|
||||
}
|
||||
|
||||
if _, err := testServer(ctx, cb, config, false); err.Error() != serverAlertError.Error() {
|
||||
if _, err := testServer(ctx, cb, config, false); !errors.Is(err, serverAlertError) {
|
||||
t.Fatalf("TestPSK: Server error exp(%v) failed(%v)", serverAlertError, err)
|
||||
}
|
||||
|
||||
if err := <-clientErr; err != pskRejected {
|
||||
if err := <-clientErr; !errors.Is(err, pskRejected) {
|
||||
t.Fatalf("TestPSK: Client error exp(%v) failed(%v)", pskRejected, err)
|
||||
}
|
||||
}
|
||||
|
@ -558,9 +554,9 @@ func TestClientTimeout(t *testing.T) {
|
|||
}()
|
||||
|
||||
// no server!
|
||||
|
||||
if err := <-clientErr; err != errHandshakeTimeout {
|
||||
t.Fatalf("Client error exp(%v) failed(%v)", errHandshakeTimeout, err)
|
||||
err := <-clientErr
|
||||
if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
|
||||
t.Fatalf("Client error exp(Temporary network error) failed(%v)", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -598,7 +594,7 @@ func TestSRTPConfiguration(t *testing.T) {
|
|||
ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
|
||||
ServerSRTP: nil,
|
||||
ExpectedProfile: 0,
|
||||
WantClientError: fmt.Errorf("alert: Alert LevelFatal: InsufficientSecurity"),
|
||||
WantClientError: &errAlert{&alert{alertLevelFatal, alertInsufficientSecurity}},
|
||||
WantServerError: errServerNoMatchingSRTPProfile,
|
||||
},
|
||||
{
|
||||
|
@ -626,10 +622,8 @@ func TestSRTPConfiguration(t *testing.T) {
|
|||
}()
|
||||
|
||||
server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: test.ServerSRTP}, true)
|
||||
if err != nil || test.WantServerError != nil {
|
||||
if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) {
|
||||
t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
|
||||
}
|
||||
if !errors.Is(err, test.WantServerError) {
|
||||
t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
|
||||
}
|
||||
if err == nil {
|
||||
defer func() {
|
||||
|
@ -643,10 +637,8 @@ func TestSRTPConfiguration(t *testing.T) {
|
|||
_ = res.c.Close()
|
||||
}()
|
||||
}
|
||||
if res.err != nil || test.WantClientError != nil {
|
||||
if !(res.err != nil && test.WantClientError != nil && res.err.Error() == test.WantClientError.Error()) {
|
||||
t.Fatalf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
|
||||
}
|
||||
if !errors.Is(res.err, test.WantClientError) {
|
||||
t.Fatalf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
|
||||
}
|
||||
if res.c == nil {
|
||||
return
|
||||
|
@ -916,7 +908,7 @@ func TestExtendedMasterSecret(t *testing.T) {
|
|||
ExtendedMasterSecret: DisableExtendedMasterSecret,
|
||||
},
|
||||
expectedClientErr: errClientRequiredButNoServerEMS,
|
||||
expectedServerErr: fmt.Errorf("alert: Alert LevelFatal: InsufficientSecurity"),
|
||||
expectedServerErr: &errAlert{&alert{alertLevelFatal, alertInsufficientSecurity}},
|
||||
},
|
||||
"Disable_Request_ExtendedMasterSecret": {
|
||||
clientCfg: &Config{
|
||||
|
@ -935,7 +927,7 @@ func TestExtendedMasterSecret(t *testing.T) {
|
|||
serverCfg: &Config{
|
||||
ExtendedMasterSecret: RequireExtendedMasterSecret,
|
||||
},
|
||||
expectedClientErr: fmt.Errorf("alert: Alert LevelFatal: InsufficientSecurity"),
|
||||
expectedClientErr: &errAlert{&alert{alertLevelFatal, alertInsufficientSecurity}},
|
||||
expectedServerErr: errServerRequiredButNoClientEMS,
|
||||
},
|
||||
"Disable_Disable_ExtendedMasterSecret": {
|
||||
|
@ -978,16 +970,12 @@ func TestExtendedMasterSecret(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
if tt.expectedClientErr != nil {
|
||||
if res.err.Error() != tt.expectedClientErr.Error() {
|
||||
t.Errorf("Client error expected: \"%v\" but got \"%v\"", tt.expectedClientErr, res.err)
|
||||
}
|
||||
if !errors.Is(res.err, tt.expectedClientErr) {
|
||||
t.Errorf("Client error expected: \"%v\" but got \"%v\"", tt.expectedClientErr, res.err)
|
||||
}
|
||||
|
||||
if tt.expectedServerErr != nil {
|
||||
if err.Error() != tt.expectedServerErr.Error() {
|
||||
t.Errorf("Server error expected: \"%v\" but got \"%v\"", tt.expectedServerErr, err)
|
||||
}
|
||||
if !errors.Is(err, tt.expectedServerErr) {
|
||||
t.Errorf("Server error expected: \"%v\" but got \"%v\"", tt.expectedServerErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -1126,8 +1114,8 @@ func TestCipherSuiteConfiguration(t *testing.T) {
|
|||
Name: "Invalid CipherSuite",
|
||||
ClientCipherSuites: []CipherSuiteID{0x00},
|
||||
ServerCipherSuites: []CipherSuiteID{0x00},
|
||||
WantClientError: errors.New("CipherSuite with id(0) is not valid"),
|
||||
WantServerError: errors.New("CipherSuite with id(0) is not valid"),
|
||||
WantClientError: &invalidCipherSuite{0x00},
|
||||
WantServerError: &invalidCipherSuite{0x00},
|
||||
},
|
||||
{
|
||||
Name: "Valid CipherSuites specified",
|
||||
|
@ -1140,7 +1128,7 @@ func TestCipherSuiteConfiguration(t *testing.T) {
|
|||
Name: "CipherSuites mismatch",
|
||||
ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
||||
ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
|
||||
WantClientError: errors.New("alert: Alert LevelFatal: InsufficientSecurity"),
|
||||
WantClientError: &errAlert{&alert{alertLevelFatal, alertInsufficientSecurity}},
|
||||
WantServerError: errCipherSuiteNoIntersection,
|
||||
},
|
||||
{
|
||||
|
@ -1181,20 +1169,16 @@ func TestCipherSuiteConfiguration(t *testing.T) {
|
|||
_ = server.Close()
|
||||
}()
|
||||
}
|
||||
if err != nil || test.WantServerError != nil {
|
||||
if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) {
|
||||
t.Errorf("TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
|
||||
}
|
||||
if !errors.Is(err, test.WantServerError) {
|
||||
t.Errorf("TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
|
||||
}
|
||||
|
||||
res := <-c
|
||||
if res.err == nil {
|
||||
_ = server.Close()
|
||||
}
|
||||
if res.err != nil || test.WantClientError != nil {
|
||||
if !(res.err != nil && test.WantClientError != nil && res.err.Error() == test.WantClientError.Error()) {
|
||||
t.Errorf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
|
||||
}
|
||||
if !errors.Is(res.err, test.WantClientError) {
|
||||
t.Errorf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -1405,8 +1389,9 @@ func TestServerTimeout(t *testing.T) {
|
|||
FlightInterval: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
if _, err := testServer(ctx, cb, config, true); err != errHandshakeTimeout {
|
||||
t.Fatalf("Client error exp(%v) failed(%v)", errHandshakeTimeout, err)
|
||||
_, serverErr := testServer(ctx, cb, config, true)
|
||||
if netErr, ok := serverErr.(net.Error); !ok || !netErr.Timeout() {
|
||||
t.Fatalf("Client error exp(Temporary network error) failed(%v)", serverErr)
|
||||
}
|
||||
|
||||
// Wait a little longer to ensure no additional messages have been sent by the server
|
||||
|
@ -1520,7 +1505,7 @@ func TestProtocolVersionValidation(t *testing.T) {
|
|||
defer wg.Wait()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if _, err := testServer(ctx, cb, config, true); err != errUnsupportedProtocolVersion {
|
||||
if _, err := testServer(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) {
|
||||
t.Errorf("Client error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
|
||||
}
|
||||
}()
|
||||
|
@ -1648,7 +1633,7 @@ func TestProtocolVersionValidation(t *testing.T) {
|
|||
defer wg.Wait()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if _, err := testClient(ctx, cb, config, true); err != errUnsupportedProtocolVersion {
|
||||
if _, err := testClient(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) {
|
||||
t.Errorf("Server error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
|
||||
}
|
||||
}()
|
||||
|
|
50
errors.go
50
errors.go
|
@ -36,7 +36,6 @@ var (
|
|||
errCompressionMethodUnset = &FatalError{errors.New("server hello can not be created without a compression method")}
|
||||
errCookieMismatch = &FatalError{errors.New("client+server cookie does not match")}
|
||||
errCookieTooLong = &FatalError{errors.New("cookie must not be longer then 255 bytes")}
|
||||
errHandshakeTimeout = &FatalError{xerrors.Errorf("the connection timed out during the handshake: %w", context.DeadlineExceeded)}
|
||||
errIdentityNoPSK = &FatalError{errors.New("PSK Identity Hint provided but PSK is nil")}
|
||||
errInvalidCertificate = &FatalError{errors.New("no certificate provided")}
|
||||
errInvalidCipherSpec = &FatalError{errors.New("cipher spec invalid")}
|
||||
|
@ -100,6 +99,27 @@ type TimeoutError struct {
|
|||
Err error
|
||||
}
|
||||
|
||||
// HandshakeError indicates that the handshake failed.
|
||||
type HandshakeError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// invalidCipherSuite indicates an attempt at using an unsupported cipher suite.
|
||||
type invalidCipherSuite struct {
|
||||
id CipherSuiteID
|
||||
}
|
||||
|
||||
func (e *invalidCipherSuite) Error() string {
|
||||
return fmt.Sprintf("CipherSuite with id(%d) is not valid", e.id)
|
||||
}
|
||||
|
||||
func (e *invalidCipherSuite) Is(err error) bool {
|
||||
if other, ok := err.(*invalidCipherSuite); ok {
|
||||
return e.id == other.id
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Timeout implements net.Error.Timeout()
|
||||
func (*FatalError) Timeout() bool { return false }
|
||||
|
||||
|
@ -144,6 +164,27 @@ func (e *TimeoutError) Unwrap() error { return e.Err }
|
|||
|
||||
func (e *TimeoutError) Error() string { return fmt.Sprintf("dtls timeout: %v", e.Err) }
|
||||
|
||||
// Timeout implements net.Error.Timeout()
|
||||
func (e *HandshakeError) Timeout() bool {
|
||||
if netErr, ok := e.Err.(net.Error); ok {
|
||||
return netErr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Temporary implements net.Error.Temporary()
|
||||
func (e *HandshakeError) Temporary() bool {
|
||||
if netErr, ok := e.Err.(net.Error); ok {
|
||||
return netErr.Temporary()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Unwrap implements Go1.13 error unwrapper.
|
||||
func (e *HandshakeError) Unwrap() error { return e.Err }
|
||||
|
||||
func (e *HandshakeError) Error() string { return fmt.Sprintf("handshake error: %v", e.Err) }
|
||||
|
||||
// errAlert wraps DTLS alert notification as an error
|
||||
type errAlert struct {
|
||||
*alert
|
||||
|
@ -157,6 +198,13 @@ func (e *errAlert) IsFatalOrCloseNotify() bool {
|
|||
return e.alertLevel == alertLevelFatal || e.alertDescription == alertCloseNotify
|
||||
}
|
||||
|
||||
func (e *errAlert) Is(err error) bool {
|
||||
if other, ok := err.(*errAlert); ok {
|
||||
return e.alertLevel == other.alertLevel && e.alertDescription == other.alertDescription
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// netError translates an error from underlying Conn to corresponding net.Error.
|
||||
func netError(err error) error {
|
||||
switch err {
|
||||
|
|
|
@ -32,6 +32,10 @@ func TestErrorUnwrap(t *testing.T) {
|
|||
&TimeoutError{errExample},
|
||||
[]error{errExample},
|
||||
},
|
||||
{
|
||||
&HandshakeError{errExample},
|
||||
[]error{errExample},
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
c := c
|
||||
|
@ -59,6 +63,8 @@ func TestErrorNetError(t *testing.T) {
|
|||
{&TemporaryError{errExample}, "dtls temporary: an example error", false, true},
|
||||
{&InternalError{errExample}, "dtls internal: an example error", false, false},
|
||||
{&TimeoutError{errExample}, "dtls timeout: an example error", true, true},
|
||||
{&HandshakeError{errExample}, "handshake error: an example error", false, false},
|
||||
{&HandshakeError{&TimeoutError{errExample}}, "handshake error: dtls timeout: an example error", true, true},
|
||||
}
|
||||
for _, c := range cases {
|
||||
c := c
|
||||
|
@ -73,6 +79,9 @@ func TestErrorNetError(t *testing.T) {
|
|||
if ne.Temporary() != c.temporary {
|
||||
t.Errorf("%T.Temporary() should be %v", c.err, c.temporary)
|
||||
}
|
||||
if ne.Error() != c.str {
|
||||
t.Errorf("%T.Error() should be %v", c.err, c.str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче