Make handshake errors from Accept non-fatal

This avoids the servers stopping when one client fails
to complete the handshake.

Addresses #255
This commit is contained in:
folbrich 2020-06-01 07:38:06 -06:00 коммит произвёл Atsushi Watanabe
Родитель 3636c18fc0
Коммит 8ad6de55a3
7 изменённых файлов: 105 добавлений и 64 удалений

Просмотреть файл

@ -142,6 +142,7 @@ Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contribu
* [Jeroen de Bruijn](https://github.com/vidavidorra)
* [bjdgyc](https://github.com/bjdgyc)
* [Jeffrey Stoke (Jeff Ctor)](https://github.com/jeffreystoke) - *Fragmentbuffer Fix*
* [Frank Olbricht](https://github.com/folbricht)
### License
MIT License - see [LICENSE](LICENSE) for full text

Просмотреть файл

@ -170,7 +170,7 @@ func parseCipherSuites(userSelectedSuites []CipherSuiteID, excludePSK, excludeNo
for _, id := range ids {
c := cipherSuiteForID(id)
if c == nil {
return nil, fmt.Errorf("CipherSuite with id(%d) is not valid", id)
return nil, &invalidCipherSuite{id}
}
cipherSuites = append(cipherSuites, c)
}

16
conn.go
Просмотреть файл

@ -2,6 +2,7 @@ package dtls
import (
"context"
"errors"
"fmt"
"io"
"net"
@ -879,16 +880,13 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
}
func (c *Conn) translateHandshakeCtxError(err error) error {
switch err {
case context.Canceled:
if c.isHandshakeCompletedSuccessfully() {
return nil
}
return err
case context.DeadlineExceeded:
return errHandshakeTimeout
if err == nil {
return nil
}
return err
if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
return nil
}
return &HandshakeError{err}
}
func (c *Conn) close(byUser bool) error {

Просмотреть файл

@ -132,8 +132,8 @@ func TestContextConfig(t *testing.T) {
d, cancel := dial.f()
conn, err := d()
defer cancel()
if err != errHandshakeTimeout {
t.Errorf("Expected error: '%v', got: '%v'", errHandshakeTimeout, err)
if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
t.Errorf("Client error exp(Temporary network error) failed(%v)", err)
close(done)
return
}

Просмотреть файл

@ -290,7 +290,7 @@ func TestHandshakeWithAlert(t *testing.T) {
cases := map[string]struct {
configServer, configClient *Config
errServer, errClient interface{}
errServer, errClient error
}{
"CipherSuiteNoIntersection": {
configServer: &Config{
@ -300,7 +300,7 @@ func TestHandshakeWithAlert(t *testing.T) {
CipherSuites: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
},
errServer: errCipherSuiteNoIntersection,
errClient: "alert: Alert LevelFatal: InsufficientSecurity",
errClient: &errAlert{&alert{alertLevelFatal, alertInsufficientSecurity}},
},
"SignatureSchemesNoIntersection": {
configServer: &Config{
@ -311,7 +311,7 @@ func TestHandshakeWithAlert(t *testing.T) {
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP521AndSHA512},
},
errServer: "alert: Alert LevelFatal: InsufficientSecurity",
errServer: &errAlert{&alert{alertLevelFatal, alertInsufficientSecurity}},
errClient: errNoAvailableSignatureSchemes,
},
}
@ -328,17 +328,13 @@ func TestHandshakeWithAlert(t *testing.T) {
}()
_, errServer := testServer(ctx, cb, testCase.configServer, true)
if errExp, ok := testCase.errServer.(error); ok && errServer != errExp {
t.Fatalf("Server error exp(%v) failed(%v)", errExp, errServer)
} else if strExp, ok := testCase.errServer.(string); ok && errServer.Error() != strExp {
t.Fatalf("Server error exp(%s) failed(%v)", strExp, errServer)
if !errors.Is(errServer, testCase.errServer) {
t.Fatalf("Server error exp(%v) failed(%v)", testCase.errServer, errServer)
}
errClient := <-clientErr
if errExp, ok := testCase.errClient.(error); ok && errClient != errExp {
t.Fatalf("Client error exp(%v) failed(%v)", errExp, errClient)
} else if strExp, ok := testCase.errClient.(string); ok && errClient.Error() != strExp {
t.Fatalf("Client error exp(%s) failed(%v)", strExp, errClient)
if !errors.Is(errClient, testCase.errClient) {
t.Fatalf("Client error exp(%v) failed(%v)", testCase.errClient, errClient)
}
})
}
@ -490,7 +486,7 @@ func TestPSKHintFail(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
serverAlertError := errors.New("alert: Alert LevelFatal: InternalError")
serverAlertError := &errAlert{&alert{alertLevelFatal, alertInternalError}}
pskRejected := errors.New("PSK Rejected")
// Limit runtime in case of deadlocks
@ -523,11 +519,11 @@ func TestPSKHintFail(t *testing.T) {
CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
}
if _, err := testServer(ctx, cb, config, false); err.Error() != serverAlertError.Error() {
if _, err := testServer(ctx, cb, config, false); !errors.Is(err, serverAlertError) {
t.Fatalf("TestPSK: Server error exp(%v) failed(%v)", serverAlertError, err)
}
if err := <-clientErr; err != pskRejected {
if err := <-clientErr; !errors.Is(err, pskRejected) {
t.Fatalf("TestPSK: Client error exp(%v) failed(%v)", pskRejected, err)
}
}
@ -558,9 +554,9 @@ func TestClientTimeout(t *testing.T) {
}()
// no server!
if err := <-clientErr; err != errHandshakeTimeout {
t.Fatalf("Client error exp(%v) failed(%v)", errHandshakeTimeout, err)
err := <-clientErr
if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
t.Fatalf("Client error exp(Temporary network error) failed(%v)", err)
}
}
@ -598,7 +594,7 @@ func TestSRTPConfiguration(t *testing.T) {
ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
ServerSRTP: nil,
ExpectedProfile: 0,
WantClientError: fmt.Errorf("alert: Alert LevelFatal: InsufficientSecurity"),
WantClientError: &errAlert{&alert{alertLevelFatal, alertInsufficientSecurity}},
WantServerError: errServerNoMatchingSRTPProfile,
},
{
@ -626,10 +622,8 @@ func TestSRTPConfiguration(t *testing.T) {
}()
server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: test.ServerSRTP}, true)
if err != nil || test.WantServerError != nil {
if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) {
t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
}
if !errors.Is(err, test.WantServerError) {
t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
}
if err == nil {
defer func() {
@ -643,10 +637,8 @@ func TestSRTPConfiguration(t *testing.T) {
_ = res.c.Close()
}()
}
if res.err != nil || test.WantClientError != nil {
if !(res.err != nil && test.WantClientError != nil && res.err.Error() == test.WantClientError.Error()) {
t.Fatalf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
}
if !errors.Is(res.err, test.WantClientError) {
t.Fatalf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
}
if res.c == nil {
return
@ -916,7 +908,7 @@ func TestExtendedMasterSecret(t *testing.T) {
ExtendedMasterSecret: DisableExtendedMasterSecret,
},
expectedClientErr: errClientRequiredButNoServerEMS,
expectedServerErr: fmt.Errorf("alert: Alert LevelFatal: InsufficientSecurity"),
expectedServerErr: &errAlert{&alert{alertLevelFatal, alertInsufficientSecurity}},
},
"Disable_Request_ExtendedMasterSecret": {
clientCfg: &Config{
@ -935,7 +927,7 @@ func TestExtendedMasterSecret(t *testing.T) {
serverCfg: &Config{
ExtendedMasterSecret: RequireExtendedMasterSecret,
},
expectedClientErr: fmt.Errorf("alert: Alert LevelFatal: InsufficientSecurity"),
expectedClientErr: &errAlert{&alert{alertLevelFatal, alertInsufficientSecurity}},
expectedServerErr: errServerRequiredButNoClientEMS,
},
"Disable_Disable_ExtendedMasterSecret": {
@ -978,16 +970,12 @@ func TestExtendedMasterSecret(t *testing.T) {
}
}()
if tt.expectedClientErr != nil {
if res.err.Error() != tt.expectedClientErr.Error() {
t.Errorf("Client error expected: \"%v\" but got \"%v\"", tt.expectedClientErr, res.err)
}
if !errors.Is(res.err, tt.expectedClientErr) {
t.Errorf("Client error expected: \"%v\" but got \"%v\"", tt.expectedClientErr, res.err)
}
if tt.expectedServerErr != nil {
if err.Error() != tt.expectedServerErr.Error() {
t.Errorf("Server error expected: \"%v\" but got \"%v\"", tt.expectedServerErr, err)
}
if !errors.Is(err, tt.expectedServerErr) {
t.Errorf("Server error expected: \"%v\" but got \"%v\"", tt.expectedServerErr, err)
}
})
}
@ -1126,8 +1114,8 @@ func TestCipherSuiteConfiguration(t *testing.T) {
Name: "Invalid CipherSuite",
ClientCipherSuites: []CipherSuiteID{0x00},
ServerCipherSuites: []CipherSuiteID{0x00},
WantClientError: errors.New("CipherSuite with id(0) is not valid"),
WantServerError: errors.New("CipherSuite with id(0) is not valid"),
WantClientError: &invalidCipherSuite{0x00},
WantServerError: &invalidCipherSuite{0x00},
},
{
Name: "Valid CipherSuites specified",
@ -1140,7 +1128,7 @@ func TestCipherSuiteConfiguration(t *testing.T) {
Name: "CipherSuites mismatch",
ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
WantClientError: errors.New("alert: Alert LevelFatal: InsufficientSecurity"),
WantClientError: &errAlert{&alert{alertLevelFatal, alertInsufficientSecurity}},
WantServerError: errCipherSuiteNoIntersection,
},
{
@ -1181,20 +1169,16 @@ func TestCipherSuiteConfiguration(t *testing.T) {
_ = server.Close()
}()
}
if err != nil || test.WantServerError != nil {
if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) {
t.Errorf("TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
}
if !errors.Is(err, test.WantServerError) {
t.Errorf("TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
}
res := <-c
if res.err == nil {
_ = server.Close()
}
if res.err != nil || test.WantClientError != nil {
if !(res.err != nil && test.WantClientError != nil && res.err.Error() == test.WantClientError.Error()) {
t.Errorf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
}
if !errors.Is(res.err, test.WantClientError) {
t.Errorf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
}
})
}
@ -1405,8 +1389,9 @@ func TestServerTimeout(t *testing.T) {
FlightInterval: 100 * time.Millisecond,
}
if _, err := testServer(ctx, cb, config, true); err != errHandshakeTimeout {
t.Fatalf("Client error exp(%v) failed(%v)", errHandshakeTimeout, err)
_, serverErr := testServer(ctx, cb, config, true)
if netErr, ok := serverErr.(net.Error); !ok || !netErr.Timeout() {
t.Fatalf("Client error exp(Temporary network error) failed(%v)", serverErr)
}
// Wait a little longer to ensure no additional messages have been sent by the server
@ -1520,7 +1505,7 @@ func TestProtocolVersionValidation(t *testing.T) {
defer wg.Wait()
go func() {
defer wg.Done()
if _, err := testServer(ctx, cb, config, true); err != errUnsupportedProtocolVersion {
if _, err := testServer(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) {
t.Errorf("Client error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
}
}()
@ -1648,7 +1633,7 @@ func TestProtocolVersionValidation(t *testing.T) {
defer wg.Wait()
go func() {
defer wg.Done()
if _, err := testClient(ctx, cb, config, true); err != errUnsupportedProtocolVersion {
if _, err := testClient(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) {
t.Errorf("Server error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
}
}()

Просмотреть файл

@ -36,7 +36,6 @@ var (
errCompressionMethodUnset = &FatalError{errors.New("server hello can not be created without a compression method")}
errCookieMismatch = &FatalError{errors.New("client+server cookie does not match")}
errCookieTooLong = &FatalError{errors.New("cookie must not be longer then 255 bytes")}
errHandshakeTimeout = &FatalError{xerrors.Errorf("the connection timed out during the handshake: %w", context.DeadlineExceeded)}
errIdentityNoPSK = &FatalError{errors.New("PSK Identity Hint provided but PSK is nil")}
errInvalidCertificate = &FatalError{errors.New("no certificate provided")}
errInvalidCipherSpec = &FatalError{errors.New("cipher spec invalid")}
@ -100,6 +99,27 @@ type TimeoutError struct {
Err error
}
// HandshakeError indicates that the handshake failed.
type HandshakeError struct {
Err error
}
// invalidCipherSuite indicates an attempt at using an unsupported cipher suite.
type invalidCipherSuite struct {
id CipherSuiteID
}
func (e *invalidCipherSuite) Error() string {
return fmt.Sprintf("CipherSuite with id(%d) is not valid", e.id)
}
func (e *invalidCipherSuite) Is(err error) bool {
if other, ok := err.(*invalidCipherSuite); ok {
return e.id == other.id
}
return false
}
// Timeout implements net.Error.Timeout()
func (*FatalError) Timeout() bool { return false }
@ -144,6 +164,27 @@ func (e *TimeoutError) Unwrap() error { return e.Err }
func (e *TimeoutError) Error() string { return fmt.Sprintf("dtls timeout: %v", e.Err) }
// Timeout implements net.Error.Timeout()
func (e *HandshakeError) Timeout() bool {
if netErr, ok := e.Err.(net.Error); ok {
return netErr.Timeout()
}
return false
}
// Temporary implements net.Error.Temporary()
func (e *HandshakeError) Temporary() bool {
if netErr, ok := e.Err.(net.Error); ok {
return netErr.Temporary()
}
return false
}
// Unwrap implements Go1.13 error unwrapper.
func (e *HandshakeError) Unwrap() error { return e.Err }
func (e *HandshakeError) Error() string { return fmt.Sprintf("handshake error: %v", e.Err) }
// errAlert wraps DTLS alert notification as an error
type errAlert struct {
*alert
@ -157,6 +198,13 @@ func (e *errAlert) IsFatalOrCloseNotify() bool {
return e.alertLevel == alertLevelFatal || e.alertDescription == alertCloseNotify
}
func (e *errAlert) Is(err error) bool {
if other, ok := err.(*errAlert); ok {
return e.alertLevel == other.alertLevel && e.alertDescription == other.alertDescription
}
return false
}
// netError translates an error from underlying Conn to corresponding net.Error.
func netError(err error) error {
switch err {

Просмотреть файл

@ -32,6 +32,10 @@ func TestErrorUnwrap(t *testing.T) {
&TimeoutError{errExample},
[]error{errExample},
},
{
&HandshakeError{errExample},
[]error{errExample},
},
}
for _, c := range cases {
c := c
@ -59,6 +63,8 @@ func TestErrorNetError(t *testing.T) {
{&TemporaryError{errExample}, "dtls temporary: an example error", false, true},
{&InternalError{errExample}, "dtls internal: an example error", false, false},
{&TimeoutError{errExample}, "dtls timeout: an example error", true, true},
{&HandshakeError{errExample}, "handshake error: an example error", false, false},
{&HandshakeError{&TimeoutError{errExample}}, "handshake error: dtls timeout: an example error", true, true},
}
for _, c := range cases {
c := c
@ -73,6 +79,9 @@ func TestErrorNetError(t *testing.T) {
if ne.Temporary() != c.temporary {
t.Errorf("%T.Temporary() should be %v", c.err, c.temporary)
}
if ne.Error() != c.str {
t.Errorf("%T.Error() should be %v", c.err, c.str)
}
})
}
}