Update lint scripts and CI configs.
This commit is contained in:
Pion 2022-04-24 02:56:40 +00:00 коммит произвёл Sean DuBois
Родитель 5c4fb0e221
Коммит 87a8adce43
33 изменённых файлов: 366 добавлений и 199 удалений

2
.github/workflows/lint.yaml поставляемый
Просмотреть файл

@ -47,5 +47,5 @@ jobs:
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
with:
version: v1.31
version: v1.45.2
args: $GOLANGCI_LINT_EXRA_ARGS

Просмотреть файл

@ -15,14 +15,22 @@ linters-settings:
linters:
enable:
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bidichk # Checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully
- contextcheck # check the function whether use a non-inherited context
- deadcode # Finds unused code
- decorder # check declaration order and count of types, constants, variables and functions
- depguard # Go linter that checks if package imports are in a list of acceptable packages
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
- dupl # Tool for code clone detection
- durationcheck # check for two durations multiplied together
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted.
- errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`.
- errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13.
- exhaustive # check exhaustiveness of enum switch statements
- exportloopref # checks for pointers to enclosing loop variables
- forcetypeassert # finds forced type assertions
- gci # Gci control golang package import order and make it always deterministic.
- gochecknoglobals # Checks that no globals are present in Go code
- gochecknoinits # Checks that no init functions are present in Go code
@ -35,40 +43,62 @@ linters:
- gofumpt # Gofumpt checks whether code was gofumpt-ed.
- goheader # Checks is file header matches to pattern
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
- golint # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with `f` at the end
- gosec # Inspects source code for security problems
- gosimple # Linter for Go source code that specializes in simplifying a code
- govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- grouper # An analyzer to analyze expression groups.
- importas # Enforces consistent import aliases
- ineffassign # Detects when assignments to existing variables are not used
- misspell # Finds commonly misspelled English words in comments
- nakedret # Finds naked returns in functions greater than a specified function length
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value.
- noctx # noctx finds sending http request without context.Context
- scopelint # Scopelint checks for unpinned variables in go programs
- predeclared # find code that shadows one of Go's predeclared identifiers
- revive # golint replacement, finds style mistakes
- staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks
- structcheck # Finds unused struct fields
- stylecheck # Stylecheck is a replacement for golint
- tagliatelle # Checks the struct tags.
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
- unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters
- unused # Checks Go code for unused constants, variables, functions and types
- varcheck # Finds unused global variables and constants
- wastedassign # wastedassign finds wasted assignment statements
- whitespace # Tool for detection of leading and trailing whitespace
disable:
- containedctx # containedctx is a linter that detects struct contained context.Context field
- cyclop # checks function and package cyclomatic complexity
- exhaustivestruct # Checks if all struct's fields are initialized
- forbidigo # Forbids identifiers
- funlen # Tool for detection of long functions
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- gomnd # An analyzer to detect magic numbers.
- ifshort # Checks that your code uses short syntax for if-statements whenever possible
- ireturn # Accept Interfaces, Return Concrete Types
- lll # Reports long lines
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- maligned # Tool to detect Go structs that would take less memory if their fields were sorted
- nestif # Reports deeply nested if statements
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- nolintlint # Reports ill-formed or insufficient nolint directives
- paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test
- prealloc # Finds slice declarations that could potentially be preallocated
- promlinter # Check Prometheus metrics naming via promlint
- rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- testpackage # linter that makes you use a separate _test package
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
- varnamelen # checks that the length of a variable's name matches its scope
- wrapcheck # Checks that errors returned from external packages are wrapped
- wsl # Whitespace Linter - Forces you to use empty lines!
issues:

Просмотреть файл

@ -40,7 +40,7 @@ func TestSimpleReadWrite(t *testing.T) {
t.Error(sErr)
}
gotHello <- struct{}{}
if sErr = server.Close(); sErr != nil {
if sErr = server.Close(); sErr != nil { //nolint:contextcheck
t.Error(sErr)
}
}()
@ -96,7 +96,7 @@ func benchmarkConn(b *testing.B, n int64) {
b.Error(err)
}
for {
if _, cErr = client.Write(hw); cErr != nil {
if _, cErr = client.Write(hw); cErr != nil { //nolint:contextcheck
b.Error(err)
}
}

Просмотреть файл

@ -64,8 +64,6 @@ func TestGetCertificate(t *testing.T) {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
cert, err := cfg.getCertificate(test.serverName)
if err != nil {
t.Fatal(err)

Просмотреть файл

@ -19,27 +19,27 @@ type CipherSuiteID = ciphersuite.ID
// Supported Cipher Suites
const (
// AES-128-CCM
TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM //nolint:revive,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 //nolint:revive,stylecheck
// AES-128-GCM-SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 //nolint:revive,stylecheck
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 //nolint:revive,stylecheck
// AES-256-CBC-SHA
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA //nolint:revive,stylecheck
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA //nolint:revive,stylecheck
TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_256_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_256_CCM_8 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM //nolint:revive,stylecheck
TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 //nolint:revive,stylecheck
TLS_PSK_WITH_AES_256_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_256_CCM_8 //nolint:revive,stylecheck
TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck
TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 //nolint:revive,stylecheck
TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 //nolint:golint,stylecheck
TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 //nolint:revive,stylecheck
)
// CipherSuiteAuthenticationType controls what authentication method is using during the handshake for a CipherSuite
@ -197,7 +197,7 @@ func parseCipherSuites(userSelectedSuites []CipherSuiteID, customCipherSuites fu
for _, id := range ids {
c := cipherSuiteForID(id, nil)
if c == nil {
return nil, &invalidCipherSuite{id}
return nil, &invalidCipherSuiteError{id}
}
cipherSuites = append(cipherSuites, c)
}

Просмотреть файл

@ -1,7 +1,7 @@
package dtls
import (
"crypto/dsa" //nolint
"crypto/dsa" //nolint:staticcheck
"crypto/rand"
"crypto/rsa"
"crypto/tls"

65
conn.go
Просмотреть файл

@ -332,7 +332,7 @@ func (c *Conn) Write(p []byte) (int, error) {
{
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Epoch: c.getLocalEpoch(),
Epoch: c.state.getLocalEpoch(),
Version: protocol.Version1_2,
},
Content: &protocol.ApplicationData{
@ -346,7 +346,7 @@ func (c *Conn) Write(p []byte) (int, error) {
// Close closes the connection.
func (c *Conn) Close() error {
err := c.close(true)
err := c.close(true) //nolint:contextcheck
c.handshakeLoopsFinished.Wait()
return err
}
@ -489,14 +489,14 @@ func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]by
SequenceNumber: seq,
}
recordlayerHeaderBytes, err := recordlayerHeader.Marshal()
rawPacket, err := recordlayerHeader.Marshal()
if err != nil {
return nil, err
}
p.record.Header = *recordlayerHeader
rawPacket := append(recordlayerHeaderBytes, handshakeFragment...)
rawPacket = append(rawPacket, handshakeFragment...)
if p.shouldEncrypt {
var err error
rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
@ -540,12 +540,12 @@ func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
offset += contentFragmentLen
headerFragmentRaw, err := headerFragment.Marshal()
fragmentedHandshake, err := headerFragment.Marshal()
if err != nil {
return nil, err
}
fragmentedHandshake := append(headerFragmentRaw, contentFragment...)
fragmentedHandshake = append(fragmentedHandshake, contentFragment...)
fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
}
@ -560,7 +560,10 @@ var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
}
func (c *Conn) readAndBuffer(ctx context.Context) error {
bufptr := poolReadBuffer.Get().(*[]byte)
bufptr, ok := poolReadBuffer.Get().(*[]byte)
if !ok {
return errFailedToAccessPoolReadBuffer
}
defer poolReadBuffer.Put(bufptr)
b := *bufptr
@ -587,13 +590,13 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {
if hs {
hasHandshake = true
}
switch e := err.(type) {
case nil:
case *errAlert:
var e *alertError
if errors.As(err, &e) {
if e.IsFatalOrCloseNotify() {
return e
}
default:
} else if err != nil {
return e
}
}
@ -623,13 +626,12 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
}
}
}
switch e := err.(type) {
case nil:
case *errAlert:
var e *alertError
if errors.As(err, &e) {
if e.IsFatalOrCloseNotify() {
return e
}
default:
} else if err != nil {
return e
}
}
@ -646,7 +648,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
}
// Validate epoch
remoteEpoch := c.getRemoteEpoch()
remoteEpoch := c.state.getRemoteEpoch()
if h.Epoch > remoteEpoch {
if h.Epoch > remoteEpoch+1 {
c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
@ -707,7 +709,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
continue
}
_ = c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
}
return true, nil, nil
@ -727,7 +729,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
}
markPacketAsValid()
return false, a, &errAlert{content}
return false, a, &alertError{content}
case *protocol.ChangeCipherSpec:
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
@ -740,7 +742,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
newRemoteEpoch := h.Epoch + 1
c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)
if c.getRemoteEpoch()+1 == newRemoteEpoch {
if c.state.getRemoteEpoch()+1 == newRemoteEpoch {
c.setRemoteEpoch(newRemoteEpoch)
markPacketAsValid()
}
@ -782,7 +784,7 @@ func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Descrip
{
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Epoch: c.getLocalEpoch(),
Epoch: c.state.getLocalEpoch(),
Version: protocol.Version1_2,
},
Content: &alert.Alert{
@ -848,8 +850,8 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
defer c.handshakeLoopsFinished.Done()
for {
if err := c.readAndBuffer(ctxRead); err != nil {
switch e := err.(type) {
case *errAlert:
var e *alertError
if errors.As(err, &e) {
if !e.IsFatalOrCloseNotify() {
if c.isHandshakeCompletedSuccessfully() {
// Pass the error to Read()
@ -861,9 +863,9 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
}
continue // non-fatal alert must not stop read loop
}
case error:
switch err {
case context.DeadlineExceeded, context.Canceled, io.EOF:
} else {
switch {
case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF):
default:
if c.isHandshakeCompletedSuccessfully() {
// Keep read loop and pass the read error to Read()
@ -876,14 +878,15 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
}
}
}
select {
case firstErr <- err:
default:
}
if e, ok := err.(*errAlert); ok {
if e != nil {
if e.IsFatalOrCloseNotify() {
_ = c.close(false)
_ = c.close(false) //nolint:contextcheck
}
}
return
@ -954,18 +957,10 @@ func (c *Conn) setLocalEpoch(epoch uint16) {
c.state.localEpoch.Store(epoch)
}
func (c *Conn) getLocalEpoch() uint16 {
return c.state.localEpoch.Load().(uint16)
}
func (c *Conn) setRemoteEpoch(epoch uint16) {
c.state.remoteEpoch.Store(epoch)
}
func (c *Conn) getRemoteEpoch() uint16 {
return c.state.remoteEpoch.Load().(uint16)
}
// LocalAddr implements net.Conn.LocalAddr
func (c *Conn) LocalAddr() net.Addr {
return c.nextConn.LocalAddr()

Просмотреть файл

@ -7,6 +7,7 @@ import (
"bytes"
"context"
"crypto/tls"
"errors"
"net"
"testing"
"time"
@ -37,7 +38,10 @@ func TestContextConfig(t *testing.T) {
defer func() {
_ = listen.Close()
}()
addr := listen.LocalAddr().(*net.UDPAddr)
addr, ok := listen.LocalAddr().(*net.UDPAddr)
if !ok {
t.Fatal("Failed to cast net.UDPAddr")
}
cert, err := selfsign.GenerateSelfSigned()
if err != nil {
@ -133,7 +137,8 @@ func TestContextConfig(t *testing.T) {
d, cancel := dial.f()
conn, err := d()
defer cancel()
if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
var netError net.Error
if !errors.As(err, &netError) || !netError.Temporary() { //nolint:staticcheck
t.Errorf("Client error exp(Temporary network error) failed(%v)", err)
close(done)
return

Просмотреть файл

@ -120,6 +120,8 @@ func TestReadWriteDeadline(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
var e net.Error
ca, cb, err := pipeMemory()
if err != nil {
t.Fatal(err)
@ -129,22 +131,22 @@ func TestReadWriteDeadline(t *testing.T) {
t.Fatal(err)
}
_, werr := ca.Write(make([]byte, 100))
if e, ok := werr.(net.Error); ok {
if errors.As(werr, &e) {
if !e.Timeout() {
t.Error("Deadline exceeded Write must return Timeout error")
}
if !e.Temporary() {
if !e.Temporary() { //nolint:staticcheck
t.Error("Deadline exceeded Write must return Temporary error")
}
} else {
t.Error("Write must return net.Error error")
}
_, rerr := ca.Read(make([]byte, 100))
if e, ok := rerr.(net.Error); ok {
if errors.As(rerr, &e) {
if !e.Timeout() {
t.Error("Deadline exceeded Read must return Timeout error")
}
if !e.Temporary() {
if !e.Temporary() { //nolint:staticcheck
t.Error("Deadline exceeded Read must return Temporary error")
}
} else {
@ -353,7 +355,7 @@ func TestHandshakeWithAlert(t *testing.T) {
CipherSuites: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
},
errServer: errCipherSuiteNoIntersection,
errClient: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
errClient: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
},
"SignatureSchemesNoIntersection": {
configServer: &Config{
@ -364,7 +366,7 @@ func TestHandshakeWithAlert(t *testing.T) {
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP521AndSHA512},
},
errServer: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
errServer: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
errClient: errNoAvailableSignatureSchemes,
},
}
@ -505,8 +507,8 @@ func TestPSK(t *testing.T) {
go func() {
conf := &Config{
PSK: func(hint []byte) ([]byte, error) {
if !bytes.Equal(test.ServerIdentity, hint) { // nolint
return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint) // nolint
if !bytes.Equal(test.ServerIdentity, hint) {
return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint) //nolint:goerr113
}
return []byte{0xAB, 0xC1, 0x23}, nil
@ -558,7 +560,7 @@ func TestPSKHintFail(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
serverAlertError := &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InternalError}}
serverAlertError := &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InternalError}}
pskRejected := errPSKRejected
// Limit runtime in case of deadlocks
@ -620,14 +622,15 @@ func TestClientTimeout(t *testing.T) {
c, err := testClient(ctx, ca, conf, true)
if err == nil {
_ = c.Close()
_ = c.Close() //nolint:contextcheck
}
clientErr <- err
}()
// no server!
err := <-clientErr
if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
var netErr net.Error
if !errors.As(err, &netErr) || !netErr.Timeout() {
t.Fatalf("Client error exp(Temporary network error) failed(%v)", err)
}
}
@ -666,7 +669,7 @@ func TestSRTPConfiguration(t *testing.T) {
ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
ServerSRTP: nil,
ExpectedProfile: 0,
WantClientError: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
WantClientError: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
WantServerError: errServerNoMatchingSRTPProfile,
},
{
@ -858,8 +861,6 @@ func TestClientCertificate(t *testing.T) {
for name, tt := range tests {
tt := tt
t.Run(name, func(t *testing.T) {
t.Parallel()
ca, cb := dpipe.Pipe()
type result struct {
c *Conn
@ -996,7 +997,7 @@ func TestExtendedMasterSecret(t *testing.T) {
ExtendedMasterSecret: DisableExtendedMasterSecret,
},
expectedClientErr: errClientRequiredButNoServerEMS,
expectedServerErr: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
expectedServerErr: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
},
"Disable_Request_ExtendedMasterSecret": {
clientCfg: &Config{
@ -1015,7 +1016,7 @@ func TestExtendedMasterSecret(t *testing.T) {
serverCfg: &Config{
ExtendedMasterSecret: RequireExtendedMasterSecret,
},
expectedClientErr: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
expectedClientErr: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
expectedServerErr: errServerRequiredButNoClientEMS,
},
"Disable_Disable_ExtendedMasterSecret": {
@ -1145,8 +1146,6 @@ func TestServerCertificate(t *testing.T) {
for name, tt := range tests {
tt := tt
t.Run(name, func(t *testing.T) {
t.Parallel()
ca, cb := dpipe.Pipe()
type result struct {
@ -1203,8 +1202,8 @@ func TestCipherSuiteConfiguration(t *testing.T) {
Name: "Invalid CipherSuite",
ClientCipherSuites: []CipherSuiteID{0x00},
ServerCipherSuites: []CipherSuiteID{0x00},
WantClientError: &invalidCipherSuite{0x00},
WantServerError: &invalidCipherSuite{0x00},
WantClientError: &invalidCipherSuiteError{0x00},
WantServerError: &invalidCipherSuiteError{0x00},
},
{
Name: "Valid CipherSuites specified",
@ -1218,7 +1217,7 @@ func TestCipherSuiteConfiguration(t *testing.T) {
Name: "CipherSuites mismatch",
ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
WantClientError: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
WantClientError: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
WantServerError: errCipherSuiteNoIntersection,
},
{
@ -1564,7 +1563,8 @@ func TestServerTimeout(t *testing.T) {
}
_, serverErr := testServer(ctx, cb, config, true)
if netErr, ok := serverErr.(net.Error); !ok || !netErr.Timeout() {
var netErr net.Error
if !errors.As(serverErr, &netErr) || !netErr.Timeout() {
t.Fatalf("Client error exp(Temporary network error) failed(%v)", serverErr)
}
@ -1879,7 +1879,11 @@ func TestMultipleHelloVerifyRequest(t *testing.T) {
if err := record.Unmarshal(resp[:n]); err != nil {
t.Fatal(err)
}
clientHello := record.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello)
clientHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello)
if !ok {
t.Fatal("Failed to cast MessageClientHello")
}
if !bytes.Equal(clientHello.Cookie, cookie) {
t.Fatalf("Wrong cookie, expected: %x, got: %x", clientHello.Cookie, cookie)
}
@ -1959,7 +1963,10 @@ func TestRenegotationInfo(t *testing.T) {
t.Fatal(err)
}
helloVerifyRequest := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest)
helloVerifyRequest, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest)
if !ok {
t.Fatal("Failed to cast MessageHelloVerifyRequest")
}
err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions)
if err != nil {
@ -1978,7 +1985,11 @@ func TestRenegotationInfo(t *testing.T) {
t.Fatal(err)
}
serverHello := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
if !ok {
t.Fatal("Failed to cast MessageServerHello")
}
gotNegotationInfo := false
for _, v := range serverHello.Extensions {
if _, ok := v.(*extension.RenegotiationInfo); ok {
@ -2052,13 +2063,22 @@ func TestServerNameIndicationExtension(t *testing.T) {
t.Fatal(err)
}
clientHello := r.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello)
clientHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello)
if !ok {
t.Fatal("Failed to cast MessageClientHello")
}
gotSNI := false
var actualServerName string
for _, v := range clientHello.Extensions {
if _, ok := v.(*extension.ServerName); ok {
gotSNI = true
actualServerName = v.(*extension.ServerName).ServerName
extensionServerName, ok := v.(*extension.ServerName)
if !ok {
t.Fatal("Failed to cast extension.ServerName")
}
actualServerName = extensionServerName.ServerName
}
}
@ -2223,17 +2243,28 @@ func TestALPNExtension(t *testing.T) {
}
if test.ExpectAlertFromServer {
a := r.Content.(*alert.Alert)
a, ok := r.Content.(*alert.Alert)
if !ok {
t.Fatal("Failed to cast alert.Alert")
}
if a.Description != test.Alert {
t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.Alert, a.Description)
}
} else {
serverHello := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
if !ok {
t.Fatal("Failed to cast handshake.MessageServerHello")
}
var negotiatedProtocol string
for _, v := range serverHello.Extensions {
if _, ok := v.(*extension.ALPN); ok {
e := v.(*extension.ALPN)
e, ok := v.(*extension.ALPN)
if !ok {
t.Fatal("Failed to cast extension.ALPN")
}
negotiatedProtocol = e.ProtocolNameList[0]
// Manipulate ServerHello
@ -2269,7 +2300,11 @@ func TestALPNExtension(t *testing.T) {
t.Fatal(err)
}
a := r2.Content.(*alert.Alert)
a, ok := r2.Content.(*alert.Alert)
if !ok {
t.Fatal("Failed to cast alert.Alert")
}
if a.Description != test.Alert {
t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.Alert, a.Description)
}
@ -2328,7 +2363,10 @@ func TestSupportedGroupsExtension(t *testing.T) {
t.Fatal(err)
}
helloVerifyRequest := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest)
helloVerifyRequest, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest)
if !ok {
t.Fatal("Failed to cast MessageHelloVerifyRequest")
}
err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions)
if err != nil {
@ -2347,7 +2385,11 @@ func TestSupportedGroupsExtension(t *testing.T) {
t.Fatal(err)
}
serverHello := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
if !ok {
t.Fatal("Failed to cast MessageServerHello")
}
gotGroups := false
for _, v := range serverHello.Extensions {
if _, ok := v.(*extension.SupportedEllipticCurves); ok {
@ -2507,7 +2549,12 @@ func (ms *memSessStore) Get(key []byte) (Session, error) {
return Session{}, nil
}
return v.(Session), nil
s, ok := v.(Session)
if !ok {
return Session{}, nil
}
return s, nil
}
func (ms *memSessStore) Del(key []byte) error {

Просмотреть файл

@ -61,6 +61,7 @@ var (
errLengthMismatch = &InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113
errSequenceNumberOverflow = &InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113
errInvalidFSMTransition = &InternalError{Err: errors.New("invalid state machine transition")} //nolint:goerr113
errFailedToAccessPoolReadBuffer = &InternalError{Err: errors.New("failed to access pool read buffer")} //nolint:goerr113
)
// FatalError indicates that the DTLS connection is no longer available.
@ -80,37 +81,39 @@ type TimeoutError = protocol.TimeoutError
// HandshakeError indicates that the handshake failed.
type HandshakeError = protocol.HandshakeError
// invalidCipherSuite indicates an attempt at using an unsupported cipher suite.
type invalidCipherSuite struct {
// errInvalidCipherSuite indicates an attempt at using an unsupported cipher suite.
type invalidCipherSuiteError struct {
id CipherSuiteID
}
func (e *invalidCipherSuite) Error() string {
func (e *invalidCipherSuiteError) Error() string {
return fmt.Sprintf("CipherSuite with id(%d) is not valid", e.id)
}
func (e *invalidCipherSuite) Is(err error) bool {
if other, ok := err.(*invalidCipherSuite); ok {
func (e *invalidCipherSuiteError) Is(err error) bool {
var other *invalidCipherSuiteError
if errors.As(err, &other) {
return e.id == other.id
}
return false
}
// errAlert wraps DTLS alert notification as an error
type errAlert struct {
type alertError struct {
*alert.Alert
}
func (e *errAlert) Error() string {
func (e *alertError) Error() string {
return fmt.Sprintf("alert: %s", e.Alert.String())
}
func (e *errAlert) IsFatalOrCloseNotify() bool {
func (e *alertError) IsFatalOrCloseNotify() bool {
return e.Level == alert.Fatal || e.Description == alert.CloseNotify
}
func (e *errAlert) Is(err error) bool {
if other, ok := err.(*errAlert); ok {
func (e *alertError) Is(err error) bool {
var other *alertError
if errors.As(err, &other) {
return e.Level == other.Level && e.Description == other.Description
}
return false
@ -118,14 +121,20 @@ func (e *errAlert) Is(err error) bool {
// netError translates an error from underlying Conn to corresponding net.Error.
func netError(err error) error {
switch err {
case io.EOF, context.Canceled, context.DeadlineExceeded:
switch {
case errors.Is(err, io.EOF), errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
// Return io.EOF and context errors as is.
return err
}
switch e := err.(type) {
case (*net.OpError):
if se, ok := e.Err.(*os.SyscallError); ok {
var (
ne net.Error
opError *net.OpError
se *os.SyscallError
)
if errors.As(err, &opError) {
if errors.As(opError, &se) {
if se.Timeout() {
return &TimeoutError{Err: err}
}
@ -133,8 +142,11 @@ func netError(err error) error {
return &TemporaryError{Err: err}
}
}
case (net.Error):
}
if errors.As(err, &ne) {
return err
}
return &FatalError{Err: err}
}

Просмотреть файл

@ -9,18 +9,11 @@
package dtls
import (
"errors"
"os"
"syscall"
)
func isOpErrorTemporary(err *os.SyscallError) bool {
if ne, ok := err.Err.(syscall.Errno); ok {
switch ne {
case syscall.ECONNREFUSED:
return true
default:
return false
}
}
return false
return errors.Is(err.Err, syscall.ECONNREFUSED)
}

Просмотреть файл

@ -7,6 +7,7 @@
package dtls
import (
"errors"
"net"
"testing"
)
@ -29,14 +30,16 @@ func TestErrorsTemporary(t *testing.T) {
if err == nil {
t.Skip("ECONNREFUSED is not set by system")
}
ne, ok := netError(err).(net.Error)
if !ok {
var ne net.Error
if !errors.As(netError(err), &ne) {
t.Fatalf("netError must return net.Error")
}
if ne.Timeout() {
t.Errorf("%v must not be timeout error", err)
}
if !ne.Temporary() {
if !ne.Temporary() { //nolint:staticcheck
t.Errorf("%v must be temporary error", err)
}
}

Просмотреть файл

@ -65,14 +65,14 @@ func TestErrorNetError(t *testing.T) {
for _, c := range cases {
c := c
t.Run(fmt.Sprintf("%T", c.err), func(t *testing.T) {
ne, ok := c.err.(net.Error)
if !ok {
var ne net.Error
if !errors.As(c.err, &ne) {
t.Fatalf("%T doesn't implement net.Error", c.err)
}
if ne.Timeout() != c.timeout {
t.Errorf("%T.Timeout() should be %v", c.err, c.timeout)
}
if ne.Temporary() != c.temporary {
if ne.Temporary() != c.temporary { //nolint:staticcheck
t.Errorf("%T.Temporary() should be %v", c.err, c.temporary)
}
if ne.Error() != c.str {

Просмотреть файл

@ -51,17 +51,10 @@ func Chat(conn io.ReadWriter) {
// Check is a helper to throw errors in the examples
func Check(err error) {
switch e := err.(type) {
case nil:
case (net.Error):
if e.Temporary() {
fmt.Printf("Warning: %v\n", err)
return
}
fmt.Printf("net.Error: %v\n", err)
panic(err)
default:
var netError net.Error
if errors.As(err, &netError) && netError.Temporary() { //nolint:staticcheck
fmt.Printf("Warning: %v\n", err)
} else if err != nil {
fmt.Printf("error: %v\n", err)
panic(err)
}

Просмотреть файл

@ -221,7 +221,7 @@ func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h
}
}
return nil, nil
return nil, nil //nolint:nilnil
}
func flight3Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {

Просмотреть файл

@ -273,7 +273,7 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han
func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange, sendingPlainText []byte) (*alert.Alert, error) { //nolint:gocognit
if state.cipherSuite.IsInitialized() {
return nil, nil
return nil, nil //nolint
}
clientRandom := state.localRandom.MarshalFixed()
@ -335,5 +335,5 @@ func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeCon
cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
return nil, nil
return nil, nil //nolint
}

Просмотреть файл

@ -31,17 +31,10 @@ func newHandshakeCache() *handshakeCache {
return &handshakeCache{}
}
func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ handshake.Type, isClient bool) bool { //nolint
func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ handshake.Type, isClient bool) {
h.mu.Lock()
defer h.mu.Unlock()
for _, i := range h.cache {
if i.messageSequence == messageSequence &&
i.isClient == isClient {
return false
}
}
h.cache = append(h.cache, &handshakeCacheItem{
data: append([]byte{}, data...),
epoch: epoch,
@ -49,7 +42,6 @@ func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ ha
typ: typ,
isClient: isClient,
})
return true
}
// returns a list handshakes that match the requested rules

Просмотреть файл

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/tls"
"errors"
"sync"
"testing"
"time"
@ -280,9 +281,10 @@ func TestHandshaker(t *testing.T) {
}
fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1)
switch err := fsm.Run(ctx, ca, handshakePreparing); err {
case context.Canceled:
case context.DeadlineExceeded:
err := fsm.Run(ctx, ca, handshakePreparing)
switch {
case errors.Is(err, context.Canceled):
case errors.Is(err, context.DeadlineExceeded):
t.Error("Timeout")
default:
t.Error(err)
@ -311,9 +313,10 @@ func TestHandshaker(t *testing.T) {
}
fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0)
switch err := fsm.Run(ctx, cb, handshakePreparing); err {
case context.Canceled:
case context.DeadlineExceeded:
err := fsm.Run(ctx, cb, handshakePreparing)
switch {
case errors.Is(err, context.Canceled):
case errors.Is(err, context.DeadlineExceeded):
t.Error("Timeout")
default:
t.Error(err)

Просмотреть файл

@ -96,7 +96,12 @@ func (c *AesCcm) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, erro
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return ccm.(*ciphersuite.CCM).Encrypt(pkt, raw)
cipherSuite, ok := ccm.(*ciphersuite.CCM)
if !ok {
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return cipherSuite.Encrypt(pkt, raw)
}
// Decrypt decrypts a single TLS RecordLayer
@ -106,5 +111,10 @@ func (c *AesCcm) Decrypt(raw []byte) ([]byte, error) {
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return ccm.(*ciphersuite.CCM).Decrypt(raw)
cipherSuite, ok := ccm.(*ciphersuite.CCM)
if !ok {
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return cipherSuite.Decrypt(raw)
}

Просмотреть файл

@ -52,26 +52,26 @@ func (i ID) String() string {
// Supported Cipher Suites
const (
// AES-128-CCM
TLS_ECDHE_ECDSA_WITH_AES_128_CCM ID = 0xc0ac //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ID = 0xc0ae //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_CCM ID = 0xc0ac //nolint:revive,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ID = 0xc0ae //nolint:revive,stylecheck
// AES-128-GCM-SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ID = 0xc02b //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ID = 0xc02f //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ID = 0xc02b //nolint:revive,stylecheck
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ID = 0xc02f //nolint:revive,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 ID = 0xc02c //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 ID = 0xc030 //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 ID = 0xc02c //nolint:revive,stylecheck
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 ID = 0xc030 //nolint:revive,stylecheck
// AES-256-CBC-SHA
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ID = 0xc00a //nolint:golint,stylecheck
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ID = 0xc014 //nolint:golint,stylecheck
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ID = 0xc00a //nolint:revive,stylecheck
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ID = 0xc014 //nolint:revive,stylecheck
TLS_PSK_WITH_AES_128_CCM ID = 0xc0a4 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CCM_8 ID = 0xc0a8 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_256_CCM_8 ID = 0xc0a9 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_GCM_SHA256 ID = 0x00a8 //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CBC_SHA256 ID = 0x00ae //nolint:golint,stylecheck
TLS_PSK_WITH_AES_128_CCM ID = 0xc0a4 //nolint:revive,stylecheck
TLS_PSK_WITH_AES_128_CCM_8 ID = 0xc0a8 //nolint:revive,stylecheck
TLS_PSK_WITH_AES_256_CCM_8 ID = 0xc0a9 //nolint:revive,stylecheck
TLS_PSK_WITH_AES_128_GCM_SHA256 ID = 0x00a8 //nolint:revive,stylecheck
TLS_PSK_WITH_AES_128_CBC_SHA256 ID = 0x00ae //nolint:revive,stylecheck
TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 ID = 0xC037 //nolint:golint,stylecheck
TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 ID = 0xC037 //nolint:revive,stylecheck
)
// AuthenticationType controls what authentication method is using during the handshake

Просмотреть файл

@ -91,7 +91,12 @@ func (c *TLSEcdheEcdsaWithAes128GcmSha256) Encrypt(pkt *recordlayer.RecordLayer,
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return gcm.(*ciphersuite.GCM).Encrypt(pkt, raw)
cipherSuite, ok := gcm.(*ciphersuite.GCM)
if !ok {
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return cipherSuite.Encrypt(pkt, raw)
}
// Decrypt decrypts a single TLS RecordLayer
@ -101,5 +106,10 @@ func (c *TLSEcdheEcdsaWithAes128GcmSha256) Decrypt(raw []byte) ([]byte, error) {
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return gcm.(*ciphersuite.GCM).Decrypt(raw)
cipherSuite, ok := gcm.(*ciphersuite.GCM)
if !ok {
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return cipherSuite.Decrypt(raw)
}

Просмотреть файл

@ -97,7 +97,12 @@ func (c *TLSEcdheEcdsaWithAes256CbcSha) Encrypt(pkt *recordlayer.RecordLayer, ra
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return cbc.(*ciphersuite.CBC).Encrypt(pkt, raw)
cipherSuite, ok := cbc.(*ciphersuite.CBC)
if !ok {
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return cipherSuite.Encrypt(pkt, raw)
}
// Decrypt decrypts a single TLS RecordLayer
@ -107,5 +112,10 @@ func (c *TLSEcdheEcdsaWithAes256CbcSha) Decrypt(raw []byte) ([]byte, error) {
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return cbc.(*ciphersuite.CBC).Decrypt(raw)
cipherSuite, ok := cbc.(*ciphersuite.CBC)
if !ok {
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return cipherSuite.Decrypt(raw)
}

Просмотреть файл

@ -101,15 +101,25 @@ func (c *TLSEcdhePskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, r
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return cbc.(*ciphersuite.CBC).Encrypt(pkt, raw)
cipherSuite, ok := cbc.(*ciphersuite.CBC)
if !ok {
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return cipherSuite.Encrypt(pkt, raw)
}
// Decrypt decrypts a single TLS RecordLayer
func (c *TLSEcdhePskWithAes128CbcSha256) Decrypt(raw []byte) ([]byte, error) {
cbc := c.cbc.Load()
if cbc == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return cbc.(*ciphersuite.CBC).Decrypt(raw)
cipherSuite, ok := cbc.(*ciphersuite.CBC)
if !ok {
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return cipherSuite.Decrypt(raw)
}

Просмотреть файл

@ -93,10 +93,15 @@ func (c *TLSPskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRando
func (c *TLSPskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
cbc := c.cbc.Load()
if cbc == nil { // !c.isInitialized()
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return cbc.(*ciphersuite.CBC).Encrypt(pkt, raw)
cipherSuite, ok := cbc.(*ciphersuite.CBC)
if !ok {
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
}
return cipherSuite.Encrypt(pkt, raw)
}
// Decrypt decrypts a single TLS RecordLayer
@ -106,5 +111,10 @@ func (c *TLSPskWithAes128CbcSha256) Decrypt(raw []byte) ([]byte, error) {
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return cbc.(*ciphersuite.CBC).Decrypt(raw)
cipherSuite, ok := cbc.(*ciphersuite.CBC)
if !ok {
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
}
return cipherSuite.Decrypt(raw)
}

Просмотреть файл

@ -6,6 +6,7 @@ package dpipe
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"testing"
@ -14,11 +15,23 @@ import (
"golang.org/x/net/nettest"
)
var errFailedToCast = fmt.Errorf("failed to cast net.Conn to conn")
func TestNetTest(t *testing.T) {
nettest.TestConn(t, func() (net.Conn, net.Conn, func(), error) {
ca, cb := Pipe()
return &closePropagator{ca.(*conn), cb.(*conn)},
&closePropagator{cb.(*conn), ca.(*conn)},
caConn, ok := ca.(*conn)
if !ok {
return nil, nil, nil, errFailedToCast
}
cbConn, ok := cb.(*conn)
if !ok {
return nil, nil, nil, errFailedToCast
}
return &closePropagator{caConn, cbConn},
&closePropagator{cbConn, caConn},
func() {
_ = ca.Close()
_ = cb.Close()

Просмотреть файл

@ -245,7 +245,6 @@ func TestRFC3610Vectors(t *testing.T) {
for idx, c := range cases {
c := c
t.Run(fmt.Sprintf("packet vector #%d", idx+1), func(t *testing.T) {
t.Parallel()
blk, err := aes.NewCipher(c.AESKey)
if err != nil {
t.Fatalf("could not initialize AES block cipher from key: %v", err)
@ -365,7 +364,12 @@ func TestSealError(t *testing.T) {
c := c
t.Run(name, func(t *testing.T) {
defer func() {
if err := recover(); !errors.Is(err.(error), c.err) {
err, ok := recover().(error)
if !ok {
t.Errorf("expected panic '%v', got '%v'", c.err, err)
}
if !errors.Is(err, c.err) {
t.Errorf("expected panic '%v', got '%v'", c.err, err)
}
}()

Просмотреть файл

@ -39,11 +39,21 @@ func NewCBC(localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMa
return nil, err
}
writeCBC, ok := cipher.NewCBCEncrypter(writeBlock, localWriteIV).(cbcMode)
if !ok {
return nil, errFailedToCast
}
readCBC, ok := cipher.NewCBCDecrypter(readBlock, remoteWriteIV).(cbcMode)
if !ok {
return nil, errFailedToCast
}
return &CBC{
writeCBC: cipher.NewCBCEncrypter(writeBlock, localWriteIV).(cbcMode),
writeCBC: writeCBC,
writeMac: localMac,
readCBC: cipher.NewCBCDecrypter(readBlock, remoteWriteIV).(cbcMode),
readCBC: readCBC,
readMac: remoteMac,
h: h,
}, nil

Просмотреть файл

@ -13,6 +13,7 @@ var (
errNotEnoughRoomForNonce = &protocol.InternalError{Err: errors.New("buffer not long enough to contain nonce")} //nolint:goerr113
errDecryptPacket = &protocol.TemporaryError{Err: errors.New("failed to decrypt packet")} //nolint:goerr113
errInvalidMAC = &protocol.TemporaryError{Err: errors.New("invalid mac")} //nolint:goerr113
errFailedToCast = &protocol.FatalError{Err: errors.New("failed to cast")} //nolint:goerr113
)
func generateAEADAdditionalData(h *recordlayer.Header, payloadLen int) []byte {

Просмотреть файл

@ -68,7 +68,7 @@ func Curves() map[Curve]bool {
// GenerateKeypair generates a keypair for the given Curve
func GenerateKeypair(c Curve) (*Keypair, error) {
switch c { //nolint:golint
switch c { //nolint:revive
case X25519:
tmp := make([]byte, 32)
if _, err := rand.Read(tmp); err != nil {

Просмотреть файл

@ -84,7 +84,8 @@ func (e *TimeoutError) Error() string { return fmt.Sprintf("dtls timeout: %v", e
// Timeout implements net.Error.Timeout()
func (e *HandshakeError) Timeout() bool {
if netErr, ok := e.Err.(net.Error); ok {
var netErr net.Error
if errors.As(e.Err, &netErr) {
return netErr.Timeout()
}
return false
@ -92,8 +93,9 @@ func (e *HandshakeError) Timeout() bool {
// Temporary implements net.Error.Temporary()
func (e *HandshakeError) Temporary() bool {
if netErr, ok := e.Err.(net.Error); ok {
return netErr.Temporary()
var netErr net.Error
if errors.As(e.Err, &netErr) {
return netErr.Temporary() //nolint
}
return false
}

Просмотреть файл

@ -7,8 +7,8 @@ import "github.com/pion/dtls/v2/pkg/protocol/extension"
type SRTPProtectionProfile = extension.SRTPProtectionProfile
const (
SRTP_AES128_CM_HMAC_SHA1_80 SRTPProtectionProfile = extension.SRTP_AES128_CM_HMAC_SHA1_80 // nolint
SRTP_AES128_CM_HMAC_SHA1_32 SRTPProtectionProfile = extension.SRTP_AES128_CM_HMAC_SHA1_32 // nolint
SRTP_AEAD_AES_128_GCM SRTPProtectionProfile = extension.SRTP_AEAD_AES_128_GCM // nolint
SRTP_AEAD_AES_256_GCM SRTPProtectionProfile = extension.SRTP_AEAD_AES_256_GCM // nolint
SRTP_AES128_CM_HMAC_SHA1_80 SRTPProtectionProfile = extension.SRTP_AES128_CM_HMAC_SHA1_80 // nolint:revive,stylecheck
SRTP_AES128_CM_HMAC_SHA1_32 SRTPProtectionProfile = extension.SRTP_AES128_CM_HMAC_SHA1_32 // nolint:revive,stylecheck
SRTP_AEAD_AES_128_GCM SRTPProtectionProfile = extension.SRTP_AEAD_AES_128_GCM // nolint:revive,stylecheck
SRTP_AEAD_AES_256_GCM SRTPProtectionProfile = extension.SRTP_AEAD_AES_256_GCM // nolint:revive,stylecheck
)

Просмотреть файл

@ -75,10 +75,10 @@ func (s *State) serialize() *serializedState {
localRnd := s.localRandom.MarshalFixed()
remoteRnd := s.remoteRandom.MarshalFixed()
epoch := s.localEpoch.Load().(uint16)
epoch := s.getLocalEpoch()
return &serializedState{
LocalEpoch: epoch,
RemoteEpoch: s.remoteEpoch.Load().(uint16),
LocalEpoch: s.getLocalEpoch(),
RemoteEpoch: s.getRemoteEpoch(),
CipherSuiteID: uint16(s.cipherSuite.ID()),
MasterSecret: s.masterSecret,
SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]),
@ -180,7 +180,7 @@ func (s *State) UnmarshalBinary(data []byte) error {
// This allows protocols to use DTLS for key establishment, but
// then use some of the keying material for their own purposes
func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
if s.localEpoch.Load().(uint16) == 0 {
if s.getLocalEpoch() == 0 {
return nil, errHandshakeInProgress
} else if len(context) != 0 {
return nil, errContextUnsupported
@ -199,3 +199,19 @@ func (s *State) ExportKeyingMaterial(label string, context []byte, length int) (
}
return prf.PHash(s.masterSecret, seed, length, s.cipherSuite.HashFunc())
}
func (s *State) getRemoteEpoch() uint16 {
remoteEpoch, ok := s.remoteEpoch.Load().(uint16)
if ok {
return remoteEpoch
}
return 0
}
func (s *State) getLocalEpoch() uint16 {
localEpoch, ok := s.localEpoch.Load().(uint16)
if ok {
return localEpoch
}
return 0
}

Просмотреть файл

@ -11,7 +11,7 @@ func findMatchingSRTPProfile(a, b []SRTPProtectionProfile) (SRTPProtectionProfil
return 0, false
}
func findMatchingCipherSuite(a, b []CipherSuite) (CipherSuite, bool) { //nolint
func findMatchingCipherSuite(a, b []CipherSuite) (CipherSuite, bool) {
for _, aSuite := range a {
for _, bSuite := range b {
if aSuite.ID() == bSuite.ID() {