Update CI configs to v0.7.2
Update lint scripts and CI configs.
This commit is contained in:
Родитель
5c4fb0e221
Коммит
87a8adce43
|
@ -47,5 +47,5 @@ jobs:
|
|||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
with:
|
||||
version: v1.31
|
||||
version: v1.45.2
|
||||
args: $GOLANGCI_LINT_EXRA_ARGS
|
||||
|
|
|
@ -15,14 +15,22 @@ linters-settings:
|
|||
linters:
|
||||
enable:
|
||||
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
|
||||
- bidichk # Checks for dangerous unicode character sequences
|
||||
- bodyclose # checks whether HTTP response body is closed successfully
|
||||
- contextcheck # check the function whether use a non-inherited context
|
||||
- deadcode # Finds unused code
|
||||
- decorder # check declaration order and count of types, constants, variables and functions
|
||||
- depguard # Go linter that checks if package imports are in a list of acceptable packages
|
||||
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
|
||||
- dupl # Tool for code clone detection
|
||||
- durationcheck # check for two durations multiplied together
|
||||
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
|
||||
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted.
|
||||
- errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`.
|
||||
- errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13.
|
||||
- exhaustive # check exhaustiveness of enum switch statements
|
||||
- exportloopref # checks for pointers to enclosing loop variables
|
||||
- forcetypeassert # finds forced type assertions
|
||||
- gci # Gci control golang package import order and make it always deterministic.
|
||||
- gochecknoglobals # Checks that no globals are present in Go code
|
||||
- gochecknoinits # Checks that no init functions are present in Go code
|
||||
|
@ -35,40 +43,62 @@ linters:
|
|||
- gofumpt # Gofumpt checks whether code was gofumpt-ed.
|
||||
- goheader # Checks is file header matches to pattern
|
||||
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
|
||||
- golint # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes
|
||||
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
|
||||
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
|
||||
- goprintffuncname # Checks that printf-like functions are named with `f` at the end
|
||||
- gosec # Inspects source code for security problems
|
||||
- gosimple # Linter for Go source code that specializes in simplifying a code
|
||||
- govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
|
||||
- grouper # An analyzer to analyze expression groups.
|
||||
- importas # Enforces consistent import aliases
|
||||
- ineffassign # Detects when assignments to existing variables are not used
|
||||
- misspell # Finds commonly misspelled English words in comments
|
||||
- nakedret # Finds naked returns in functions greater than a specified function length
|
||||
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
|
||||
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value.
|
||||
- noctx # noctx finds sending http request without context.Context
|
||||
- scopelint # Scopelint checks for unpinned variables in go programs
|
||||
- predeclared # find code that shadows one of Go's predeclared identifiers
|
||||
- revive # golint replacement, finds style mistakes
|
||||
- staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks
|
||||
- structcheck # Finds unused struct fields
|
||||
- stylecheck # Stylecheck is a replacement for golint
|
||||
- tagliatelle # Checks the struct tags.
|
||||
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
|
||||
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
|
||||
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
|
||||
- unconvert # Remove unnecessary type conversions
|
||||
- unparam # Reports unused function parameters
|
||||
- unused # Checks Go code for unused constants, variables, functions and types
|
||||
- varcheck # Finds unused global variables and constants
|
||||
- wastedassign # wastedassign finds wasted assignment statements
|
||||
- whitespace # Tool for detection of leading and trailing whitespace
|
||||
disable:
|
||||
- containedctx # containedctx is a linter that detects struct contained context.Context field
|
||||
- cyclop # checks function and package cyclomatic complexity
|
||||
- exhaustivestruct # Checks if all struct's fields are initialized
|
||||
- forbidigo # Forbids identifiers
|
||||
- funlen # Tool for detection of long functions
|
||||
- gocyclo # Computes and checks the cyclomatic complexity of functions
|
||||
- godot # Check if comments end in a period
|
||||
- gomnd # An analyzer to detect magic numbers.
|
||||
- ifshort # Checks that your code uses short syntax for if-statements whenever possible
|
||||
- ireturn # Accept Interfaces, Return Concrete Types
|
||||
- lll # Reports long lines
|
||||
- maintidx # maintidx measures the maintainability index of each function.
|
||||
- makezero # Finds slice declarations with non-zero initial length
|
||||
- maligned # Tool to detect Go structs that would take less memory if their fields were sorted
|
||||
- nestif # Reports deeply nested if statements
|
||||
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
|
||||
- nolintlint # Reports ill-formed or insufficient nolint directives
|
||||
- paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test
|
||||
- prealloc # Finds slice declarations that could potentially be preallocated
|
||||
- promlinter # Check Prometheus metrics naming via promlint
|
||||
- rowserrcheck # checks whether Err of rows is checked successfully
|
||||
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
|
||||
- testpackage # linter that makes you use a separate _test package
|
||||
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
|
||||
- varnamelen # checks that the length of a variable's name matches its scope
|
||||
- wrapcheck # Checks that errors returned from external packages are wrapped
|
||||
- wsl # Whitespace Linter - Forces you to use empty lines!
|
||||
|
||||
issues:
|
||||
|
|
|
@ -40,7 +40,7 @@ func TestSimpleReadWrite(t *testing.T) {
|
|||
t.Error(sErr)
|
||||
}
|
||||
gotHello <- struct{}{}
|
||||
if sErr = server.Close(); sErr != nil {
|
||||
if sErr = server.Close(); sErr != nil { //nolint:contextcheck
|
||||
t.Error(sErr)
|
||||
}
|
||||
}()
|
||||
|
@ -96,7 +96,7 @@ func benchmarkConn(b *testing.B, n int64) {
|
|||
b.Error(err)
|
||||
}
|
||||
for {
|
||||
if _, cErr = client.Write(hw); cErr != nil {
|
||||
if _, cErr = client.Write(hw); cErr != nil { //nolint:contextcheck
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -64,8 +64,6 @@ func TestGetCertificate(t *testing.T) {
|
|||
test := test
|
||||
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cert, err := cfg.getCertificate(test.serverName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
|
@ -19,27 +19,27 @@ type CipherSuiteID = ciphersuite.ID
|
|||
// Supported Cipher Suites
|
||||
const (
|
||||
// AES-128-CCM
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM //nolint:golint,stylecheck
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 //nolint:golint,stylecheck
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM //nolint:revive,stylecheck
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 //nolint:revive,stylecheck
|
||||
|
||||
// AES-128-GCM-SHA256
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 //nolint:golint,stylecheck
|
||||
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 //nolint:golint,stylecheck
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck
|
||||
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck
|
||||
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 //nolint:golint,stylecheck
|
||||
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 //nolint:golint,stylecheck
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 //nolint:revive,stylecheck
|
||||
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 //nolint:revive,stylecheck
|
||||
|
||||
// AES-256-CBC-SHA
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA //nolint:golint,stylecheck
|
||||
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA //nolint:golint,stylecheck
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA //nolint:revive,stylecheck
|
||||
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA //nolint:revive,stylecheck
|
||||
|
||||
TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM //nolint:golint,stylecheck
|
||||
TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 //nolint:golint,stylecheck
|
||||
TLS_PSK_WITH_AES_256_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_256_CCM_8 //nolint:golint,stylecheck
|
||||
TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 //nolint:golint,stylecheck
|
||||
TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 //nolint:golint,stylecheck
|
||||
TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM //nolint:revive,stylecheck
|
||||
TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 //nolint:revive,stylecheck
|
||||
TLS_PSK_WITH_AES_256_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_256_CCM_8 //nolint:revive,stylecheck
|
||||
TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck
|
||||
TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 //nolint:revive,stylecheck
|
||||
|
||||
TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 //nolint:golint,stylecheck
|
||||
TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 //nolint:revive,stylecheck
|
||||
)
|
||||
|
||||
// CipherSuiteAuthenticationType controls what authentication method is using during the handshake for a CipherSuite
|
||||
|
@ -197,7 +197,7 @@ func parseCipherSuites(userSelectedSuites []CipherSuiteID, customCipherSuites fu
|
|||
for _, id := range ids {
|
||||
c := cipherSuiteForID(id, nil)
|
||||
if c == nil {
|
||||
return nil, &invalidCipherSuite{id}
|
||||
return nil, &invalidCipherSuiteError{id}
|
||||
}
|
||||
cipherSuites = append(cipherSuites, c)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package dtls
|
||||
|
||||
import (
|
||||
"crypto/dsa" //nolint
|
||||
"crypto/dsa" //nolint:staticcheck
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
|
|
65
conn.go
65
conn.go
|
@ -332,7 +332,7 @@ func (c *Conn) Write(p []byte) (int, error) {
|
|||
{
|
||||
record: &recordlayer.RecordLayer{
|
||||
Header: recordlayer.Header{
|
||||
Epoch: c.getLocalEpoch(),
|
||||
Epoch: c.state.getLocalEpoch(),
|
||||
Version: protocol.Version1_2,
|
||||
},
|
||||
Content: &protocol.ApplicationData{
|
||||
|
@ -346,7 +346,7 @@ func (c *Conn) Write(p []byte) (int, error) {
|
|||
|
||||
// Close closes the connection.
|
||||
func (c *Conn) Close() error {
|
||||
err := c.close(true)
|
||||
err := c.close(true) //nolint:contextcheck
|
||||
c.handshakeLoopsFinished.Wait()
|
||||
return err
|
||||
}
|
||||
|
@ -489,14 +489,14 @@ func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]by
|
|||
SequenceNumber: seq,
|
||||
}
|
||||
|
||||
recordlayerHeaderBytes, err := recordlayerHeader.Marshal()
|
||||
rawPacket, err := recordlayerHeader.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.record.Header = *recordlayerHeader
|
||||
|
||||
rawPacket := append(recordlayerHeaderBytes, handshakeFragment...)
|
||||
rawPacket = append(rawPacket, handshakeFragment...)
|
||||
if p.shouldEncrypt {
|
||||
var err error
|
||||
rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
|
||||
|
@ -540,12 +540,12 @@ func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
|
|||
|
||||
offset += contentFragmentLen
|
||||
|
||||
headerFragmentRaw, err := headerFragment.Marshal()
|
||||
fragmentedHandshake, err := headerFragment.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fragmentedHandshake := append(headerFragmentRaw, contentFragment...)
|
||||
fragmentedHandshake = append(fragmentedHandshake, contentFragment...)
|
||||
fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
|
||||
}
|
||||
|
||||
|
@ -560,7 +560,10 @@ var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
|
|||
}
|
||||
|
||||
func (c *Conn) readAndBuffer(ctx context.Context) error {
|
||||
bufptr := poolReadBuffer.Get().(*[]byte)
|
||||
bufptr, ok := poolReadBuffer.Get().(*[]byte)
|
||||
if !ok {
|
||||
return errFailedToAccessPoolReadBuffer
|
||||
}
|
||||
defer poolReadBuffer.Put(bufptr)
|
||||
|
||||
b := *bufptr
|
||||
|
@ -587,13 +590,13 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {
|
|||
if hs {
|
||||
hasHandshake = true
|
||||
}
|
||||
switch e := err.(type) {
|
||||
case nil:
|
||||
case *errAlert:
|
||||
|
||||
var e *alertError
|
||||
if errors.As(err, &e) {
|
||||
if e.IsFatalOrCloseNotify() {
|
||||
return e
|
||||
}
|
||||
default:
|
||||
} else if err != nil {
|
||||
return e
|
||||
}
|
||||
}
|
||||
|
@ -623,13 +626,12 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
}
|
||||
switch e := err.(type) {
|
||||
case nil:
|
||||
case *errAlert:
|
||||
var e *alertError
|
||||
if errors.As(err, &e) {
|
||||
if e.IsFatalOrCloseNotify() {
|
||||
return e
|
||||
}
|
||||
default:
|
||||
} else if err != nil {
|
||||
return e
|
||||
}
|
||||
}
|
||||
|
@ -646,7 +648,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
|
|||
}
|
||||
|
||||
// Validate epoch
|
||||
remoteEpoch := c.getRemoteEpoch()
|
||||
remoteEpoch := c.state.getRemoteEpoch()
|
||||
if h.Epoch > remoteEpoch {
|
||||
if h.Epoch > remoteEpoch+1 {
|
||||
c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
|
||||
|
@ -707,7 +709,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
|
|||
c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
|
||||
continue
|
||||
}
|
||||
_ = c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
|
||||
c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
|
||||
}
|
||||
|
||||
return true, nil, nil
|
||||
|
@ -727,7 +729,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
|
|||
a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
|
||||
}
|
||||
markPacketAsValid()
|
||||
return false, a, &errAlert{content}
|
||||
return false, a, &alertError{content}
|
||||
case *protocol.ChangeCipherSpec:
|
||||
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
|
||||
if enqueue {
|
||||
|
@ -740,7 +742,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
|
|||
newRemoteEpoch := h.Epoch + 1
|
||||
c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)
|
||||
|
||||
if c.getRemoteEpoch()+1 == newRemoteEpoch {
|
||||
if c.state.getRemoteEpoch()+1 == newRemoteEpoch {
|
||||
c.setRemoteEpoch(newRemoteEpoch)
|
||||
markPacketAsValid()
|
||||
}
|
||||
|
@ -782,7 +784,7 @@ func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Descrip
|
|||
{
|
||||
record: &recordlayer.RecordLayer{
|
||||
Header: recordlayer.Header{
|
||||
Epoch: c.getLocalEpoch(),
|
||||
Epoch: c.state.getLocalEpoch(),
|
||||
Version: protocol.Version1_2,
|
||||
},
|
||||
Content: &alert.Alert{
|
||||
|
@ -848,8 +850,8 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
|
|||
defer c.handshakeLoopsFinished.Done()
|
||||
for {
|
||||
if err := c.readAndBuffer(ctxRead); err != nil {
|
||||
switch e := err.(type) {
|
||||
case *errAlert:
|
||||
var e *alertError
|
||||
if errors.As(err, &e) {
|
||||
if !e.IsFatalOrCloseNotify() {
|
||||
if c.isHandshakeCompletedSuccessfully() {
|
||||
// Pass the error to Read()
|
||||
|
@ -861,9 +863,9 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
|
|||
}
|
||||
continue // non-fatal alert must not stop read loop
|
||||
}
|
||||
case error:
|
||||
switch err {
|
||||
case context.DeadlineExceeded, context.Canceled, io.EOF:
|
||||
} else {
|
||||
switch {
|
||||
case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF):
|
||||
default:
|
||||
if c.isHandshakeCompletedSuccessfully() {
|
||||
// Keep read loop and pass the read error to Read()
|
||||
|
@ -876,14 +878,15 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case firstErr <- err:
|
||||
default:
|
||||
}
|
||||
|
||||
if e, ok := err.(*errAlert); ok {
|
||||
if e != nil {
|
||||
if e.IsFatalOrCloseNotify() {
|
||||
_ = c.close(false)
|
||||
_ = c.close(false) //nolint:contextcheck
|
||||
}
|
||||
}
|
||||
return
|
||||
|
@ -954,18 +957,10 @@ func (c *Conn) setLocalEpoch(epoch uint16) {
|
|||
c.state.localEpoch.Store(epoch)
|
||||
}
|
||||
|
||||
func (c *Conn) getLocalEpoch() uint16 {
|
||||
return c.state.localEpoch.Load().(uint16)
|
||||
}
|
||||
|
||||
func (c *Conn) setRemoteEpoch(epoch uint16) {
|
||||
c.state.remoteEpoch.Store(epoch)
|
||||
}
|
||||
|
||||
func (c *Conn) getRemoteEpoch() uint16 {
|
||||
return c.state.remoteEpoch.Load().(uint16)
|
||||
}
|
||||
|
||||
// LocalAddr implements net.Conn.LocalAddr
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.nextConn.LocalAddr()
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -37,7 +38,10 @@ func TestContextConfig(t *testing.T) {
|
|||
defer func() {
|
||||
_ = listen.Close()
|
||||
}()
|
||||
addr := listen.LocalAddr().(*net.UDPAddr)
|
||||
addr, ok := listen.LocalAddr().(*net.UDPAddr)
|
||||
if !ok {
|
||||
t.Fatal("Failed to cast net.UDPAddr")
|
||||
}
|
||||
|
||||
cert, err := selfsign.GenerateSelfSigned()
|
||||
if err != nil {
|
||||
|
@ -133,7 +137,8 @@ func TestContextConfig(t *testing.T) {
|
|||
d, cancel := dial.f()
|
||||
conn, err := d()
|
||||
defer cancel()
|
||||
if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
|
||||
var netError net.Error
|
||||
if !errors.As(err, &netError) || !netError.Temporary() { //nolint:staticcheck
|
||||
t.Errorf("Client error exp(Temporary network error) failed(%v)", err)
|
||||
close(done)
|
||||
return
|
||||
|
|
115
conn_test.go
115
conn_test.go
|
@ -120,6 +120,8 @@ func TestReadWriteDeadline(t *testing.T) {
|
|||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
var e net.Error
|
||||
|
||||
ca, cb, err := pipeMemory()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -129,22 +131,22 @@ func TestReadWriteDeadline(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
_, werr := ca.Write(make([]byte, 100))
|
||||
if e, ok := werr.(net.Error); ok {
|
||||
if errors.As(werr, &e) {
|
||||
if !e.Timeout() {
|
||||
t.Error("Deadline exceeded Write must return Timeout error")
|
||||
}
|
||||
if !e.Temporary() {
|
||||
if !e.Temporary() { //nolint:staticcheck
|
||||
t.Error("Deadline exceeded Write must return Temporary error")
|
||||
}
|
||||
} else {
|
||||
t.Error("Write must return net.Error error")
|
||||
}
|
||||
_, rerr := ca.Read(make([]byte, 100))
|
||||
if e, ok := rerr.(net.Error); ok {
|
||||
if errors.As(rerr, &e) {
|
||||
if !e.Timeout() {
|
||||
t.Error("Deadline exceeded Read must return Timeout error")
|
||||
}
|
||||
if !e.Temporary() {
|
||||
if !e.Temporary() { //nolint:staticcheck
|
||||
t.Error("Deadline exceeded Read must return Temporary error")
|
||||
}
|
||||
} else {
|
||||
|
@ -353,7 +355,7 @@ func TestHandshakeWithAlert(t *testing.T) {
|
|||
CipherSuites: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
|
||||
},
|
||||
errServer: errCipherSuiteNoIntersection,
|
||||
errClient: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
||||
errClient: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
||||
},
|
||||
"SignatureSchemesNoIntersection": {
|
||||
configServer: &Config{
|
||||
|
@ -364,7 +366,7 @@ func TestHandshakeWithAlert(t *testing.T) {
|
|||
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
||||
SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP521AndSHA512},
|
||||
},
|
||||
errServer: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
||||
errServer: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
||||
errClient: errNoAvailableSignatureSchemes,
|
||||
},
|
||||
}
|
||||
|
@ -505,8 +507,8 @@ func TestPSK(t *testing.T) {
|
|||
go func() {
|
||||
conf := &Config{
|
||||
PSK: func(hint []byte) ([]byte, error) {
|
||||
if !bytes.Equal(test.ServerIdentity, hint) { // nolint
|
||||
return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint) // nolint
|
||||
if !bytes.Equal(test.ServerIdentity, hint) {
|
||||
return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint) //nolint:goerr113
|
||||
}
|
||||
|
||||
return []byte{0xAB, 0xC1, 0x23}, nil
|
||||
|
@ -558,7 +560,7 @@ func TestPSKHintFail(t *testing.T) {
|
|||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
serverAlertError := &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InternalError}}
|
||||
serverAlertError := &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InternalError}}
|
||||
pskRejected := errPSKRejected
|
||||
|
||||
// Limit runtime in case of deadlocks
|
||||
|
@ -620,14 +622,15 @@ func TestClientTimeout(t *testing.T) {
|
|||
|
||||
c, err := testClient(ctx, ca, conf, true)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
_ = c.Close() //nolint:contextcheck
|
||||
}
|
||||
clientErr <- err
|
||||
}()
|
||||
|
||||
// no server!
|
||||
err := <-clientErr
|
||||
if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
|
||||
var netErr net.Error
|
||||
if !errors.As(err, &netErr) || !netErr.Timeout() {
|
||||
t.Fatalf("Client error exp(Temporary network error) failed(%v)", err)
|
||||
}
|
||||
}
|
||||
|
@ -666,7 +669,7 @@ func TestSRTPConfiguration(t *testing.T) {
|
|||
ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
|
||||
ServerSRTP: nil,
|
||||
ExpectedProfile: 0,
|
||||
WantClientError: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
||||
WantClientError: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
||||
WantServerError: errServerNoMatchingSRTPProfile,
|
||||
},
|
||||
{
|
||||
|
@ -858,8 +861,6 @@ func TestClientCertificate(t *testing.T) {
|
|||
for name, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ca, cb := dpipe.Pipe()
|
||||
type result struct {
|
||||
c *Conn
|
||||
|
@ -996,7 +997,7 @@ func TestExtendedMasterSecret(t *testing.T) {
|
|||
ExtendedMasterSecret: DisableExtendedMasterSecret,
|
||||
},
|
||||
expectedClientErr: errClientRequiredButNoServerEMS,
|
||||
expectedServerErr: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
||||
expectedServerErr: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
||||
},
|
||||
"Disable_Request_ExtendedMasterSecret": {
|
||||
clientCfg: &Config{
|
||||
|
@ -1015,7 +1016,7 @@ func TestExtendedMasterSecret(t *testing.T) {
|
|||
serverCfg: &Config{
|
||||
ExtendedMasterSecret: RequireExtendedMasterSecret,
|
||||
},
|
||||
expectedClientErr: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
||||
expectedClientErr: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
||||
expectedServerErr: errServerRequiredButNoClientEMS,
|
||||
},
|
||||
"Disable_Disable_ExtendedMasterSecret": {
|
||||
|
@ -1145,8 +1146,6 @@ func TestServerCertificate(t *testing.T) {
|
|||
for name, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ca, cb := dpipe.Pipe()
|
||||
|
||||
type result struct {
|
||||
|
@ -1203,8 +1202,8 @@ func TestCipherSuiteConfiguration(t *testing.T) {
|
|||
Name: "Invalid CipherSuite",
|
||||
ClientCipherSuites: []CipherSuiteID{0x00},
|
||||
ServerCipherSuites: []CipherSuiteID{0x00},
|
||||
WantClientError: &invalidCipherSuite{0x00},
|
||||
WantServerError: &invalidCipherSuite{0x00},
|
||||
WantClientError: &invalidCipherSuiteError{0x00},
|
||||
WantServerError: &invalidCipherSuiteError{0x00},
|
||||
},
|
||||
{
|
||||
Name: "Valid CipherSuites specified",
|
||||
|
@ -1218,7 +1217,7 @@ func TestCipherSuiteConfiguration(t *testing.T) {
|
|||
Name: "CipherSuites mismatch",
|
||||
ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
|
||||
ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
|
||||
WantClientError: &errAlert{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
||||
WantClientError: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
|
||||
WantServerError: errCipherSuiteNoIntersection,
|
||||
},
|
||||
{
|
||||
|
@ -1564,7 +1563,8 @@ func TestServerTimeout(t *testing.T) {
|
|||
}
|
||||
|
||||
_, serverErr := testServer(ctx, cb, config, true)
|
||||
if netErr, ok := serverErr.(net.Error); !ok || !netErr.Timeout() {
|
||||
var netErr net.Error
|
||||
if !errors.As(serverErr, &netErr) || !netErr.Timeout() {
|
||||
t.Fatalf("Client error exp(Temporary network error) failed(%v)", serverErr)
|
||||
}
|
||||
|
||||
|
@ -1879,7 +1879,11 @@ func TestMultipleHelloVerifyRequest(t *testing.T) {
|
|||
if err := record.Unmarshal(resp[:n]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
clientHello := record.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello)
|
||||
clientHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello)
|
||||
if !ok {
|
||||
t.Fatal("Failed to cast MessageClientHello")
|
||||
}
|
||||
|
||||
if !bytes.Equal(clientHello.Cookie, cookie) {
|
||||
t.Fatalf("Wrong cookie, expected: %x, got: %x", clientHello.Cookie, cookie)
|
||||
}
|
||||
|
@ -1959,7 +1963,10 @@ func TestRenegotationInfo(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
helloVerifyRequest := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest)
|
||||
helloVerifyRequest, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest)
|
||||
if !ok {
|
||||
t.Fatal("Failed to cast MessageHelloVerifyRequest")
|
||||
}
|
||||
|
||||
err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions)
|
||||
if err != nil {
|
||||
|
@ -1978,7 +1985,11 @@ func TestRenegotationInfo(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serverHello := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
|
||||
serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
|
||||
if !ok {
|
||||
t.Fatal("Failed to cast MessageServerHello")
|
||||
}
|
||||
|
||||
gotNegotationInfo := false
|
||||
for _, v := range serverHello.Extensions {
|
||||
if _, ok := v.(*extension.RenegotiationInfo); ok {
|
||||
|
@ -2052,13 +2063,22 @@ func TestServerNameIndicationExtension(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clientHello := r.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello)
|
||||
clientHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello)
|
||||
if !ok {
|
||||
t.Fatal("Failed to cast MessageClientHello")
|
||||
}
|
||||
|
||||
gotSNI := false
|
||||
var actualServerName string
|
||||
for _, v := range clientHello.Extensions {
|
||||
if _, ok := v.(*extension.ServerName); ok {
|
||||
gotSNI = true
|
||||
actualServerName = v.(*extension.ServerName).ServerName
|
||||
extensionServerName, ok := v.(*extension.ServerName)
|
||||
if !ok {
|
||||
t.Fatal("Failed to cast extension.ServerName")
|
||||
}
|
||||
|
||||
actualServerName = extensionServerName.ServerName
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2223,17 +2243,28 @@ func TestALPNExtension(t *testing.T) {
|
|||
}
|
||||
|
||||
if test.ExpectAlertFromServer {
|
||||
a := r.Content.(*alert.Alert)
|
||||
a, ok := r.Content.(*alert.Alert)
|
||||
if !ok {
|
||||
t.Fatal("Failed to cast alert.Alert")
|
||||
}
|
||||
|
||||
if a.Description != test.Alert {
|
||||
t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.Alert, a.Description)
|
||||
}
|
||||
} else {
|
||||
serverHello := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
|
||||
serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
|
||||
if !ok {
|
||||
t.Fatal("Failed to cast handshake.MessageServerHello")
|
||||
}
|
||||
|
||||
var negotiatedProtocol string
|
||||
for _, v := range serverHello.Extensions {
|
||||
if _, ok := v.(*extension.ALPN); ok {
|
||||
e := v.(*extension.ALPN)
|
||||
e, ok := v.(*extension.ALPN)
|
||||
if !ok {
|
||||
t.Fatal("Failed to cast extension.ALPN")
|
||||
}
|
||||
|
||||
negotiatedProtocol = e.ProtocolNameList[0]
|
||||
|
||||
// Manipulate ServerHello
|
||||
|
@ -2269,7 +2300,11 @@ func TestALPNExtension(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
a := r2.Content.(*alert.Alert)
|
||||
a, ok := r2.Content.(*alert.Alert)
|
||||
if !ok {
|
||||
t.Fatal("Failed to cast alert.Alert")
|
||||
}
|
||||
|
||||
if a.Description != test.Alert {
|
||||
t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.Alert, a.Description)
|
||||
}
|
||||
|
@ -2328,7 +2363,10 @@ func TestSupportedGroupsExtension(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
helloVerifyRequest := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest)
|
||||
helloVerifyRequest, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest)
|
||||
if !ok {
|
||||
t.Fatal("Failed to cast MessageHelloVerifyRequest")
|
||||
}
|
||||
|
||||
err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions)
|
||||
if err != nil {
|
||||
|
@ -2347,7 +2385,11 @@ func TestSupportedGroupsExtension(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serverHello := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
|
||||
serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
|
||||
if !ok {
|
||||
t.Fatal("Failed to cast MessageServerHello")
|
||||
}
|
||||
|
||||
gotGroups := false
|
||||
for _, v := range serverHello.Extensions {
|
||||
if _, ok := v.(*extension.SupportedEllipticCurves); ok {
|
||||
|
@ -2507,7 +2549,12 @@ func (ms *memSessStore) Get(key []byte) (Session, error) {
|
|||
return Session{}, nil
|
||||
}
|
||||
|
||||
return v.(Session), nil
|
||||
s, ok := v.(Session)
|
||||
if !ok {
|
||||
return Session{}, nil
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (ms *memSessStore) Del(key []byte) error {
|
||||
|
|
44
errors.go
44
errors.go
|
@ -61,6 +61,7 @@ var (
|
|||
errLengthMismatch = &InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113
|
||||
errSequenceNumberOverflow = &InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113
|
||||
errInvalidFSMTransition = &InternalError{Err: errors.New("invalid state machine transition")} //nolint:goerr113
|
||||
errFailedToAccessPoolReadBuffer = &InternalError{Err: errors.New("failed to access pool read buffer")} //nolint:goerr113
|
||||
)
|
||||
|
||||
// FatalError indicates that the DTLS connection is no longer available.
|
||||
|
@ -80,37 +81,39 @@ type TimeoutError = protocol.TimeoutError
|
|||
// HandshakeError indicates that the handshake failed.
|
||||
type HandshakeError = protocol.HandshakeError
|
||||
|
||||
// invalidCipherSuite indicates an attempt at using an unsupported cipher suite.
|
||||
type invalidCipherSuite struct {
|
||||
// errInvalidCipherSuite indicates an attempt at using an unsupported cipher suite.
|
||||
type invalidCipherSuiteError struct {
|
||||
id CipherSuiteID
|
||||
}
|
||||
|
||||
func (e *invalidCipherSuite) Error() string {
|
||||
func (e *invalidCipherSuiteError) Error() string {
|
||||
return fmt.Sprintf("CipherSuite with id(%d) is not valid", e.id)
|
||||
}
|
||||
|
||||
func (e *invalidCipherSuite) Is(err error) bool {
|
||||
if other, ok := err.(*invalidCipherSuite); ok {
|
||||
func (e *invalidCipherSuiteError) Is(err error) bool {
|
||||
var other *invalidCipherSuiteError
|
||||
if errors.As(err, &other) {
|
||||
return e.id == other.id
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// errAlert wraps DTLS alert notification as an error
|
||||
type errAlert struct {
|
||||
type alertError struct {
|
||||
*alert.Alert
|
||||
}
|
||||
|
||||
func (e *errAlert) Error() string {
|
||||
func (e *alertError) Error() string {
|
||||
return fmt.Sprintf("alert: %s", e.Alert.String())
|
||||
}
|
||||
|
||||
func (e *errAlert) IsFatalOrCloseNotify() bool {
|
||||
func (e *alertError) IsFatalOrCloseNotify() bool {
|
||||
return e.Level == alert.Fatal || e.Description == alert.CloseNotify
|
||||
}
|
||||
|
||||
func (e *errAlert) Is(err error) bool {
|
||||
if other, ok := err.(*errAlert); ok {
|
||||
func (e *alertError) Is(err error) bool {
|
||||
var other *alertError
|
||||
if errors.As(err, &other) {
|
||||
return e.Level == other.Level && e.Description == other.Description
|
||||
}
|
||||
return false
|
||||
|
@ -118,14 +121,20 @@ func (e *errAlert) Is(err error) bool {
|
|||
|
||||
// netError translates an error from underlying Conn to corresponding net.Error.
|
||||
func netError(err error) error {
|
||||
switch err {
|
||||
case io.EOF, context.Canceled, context.DeadlineExceeded:
|
||||
switch {
|
||||
case errors.Is(err, io.EOF), errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
|
||||
// Return io.EOF and context errors as is.
|
||||
return err
|
||||
}
|
||||
switch e := err.(type) {
|
||||
case (*net.OpError):
|
||||
if se, ok := e.Err.(*os.SyscallError); ok {
|
||||
|
||||
var (
|
||||
ne net.Error
|
||||
opError *net.OpError
|
||||
se *os.SyscallError
|
||||
)
|
||||
|
||||
if errors.As(err, &opError) {
|
||||
if errors.As(opError, &se) {
|
||||
if se.Timeout() {
|
||||
return &TimeoutError{Err: err}
|
||||
}
|
||||
|
@ -133,8 +142,11 @@ func netError(err error) error {
|
|||
return &TemporaryError{Err: err}
|
||||
}
|
||||
}
|
||||
case (net.Error):
|
||||
}
|
||||
|
||||
if errors.As(err, &ne) {
|
||||
return err
|
||||
}
|
||||
|
||||
return &FatalError{Err: err}
|
||||
}
|
||||
|
|
|
@ -9,18 +9,11 @@
|
|||
package dtls
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func isOpErrorTemporary(err *os.SyscallError) bool {
|
||||
if ne, ok := err.Err.(syscall.Errno); ok {
|
||||
switch ne {
|
||||
case syscall.ECONNREFUSED:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
return errors.Is(err.Err, syscall.ECONNREFUSED)
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
package dtls
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
@ -29,14 +30,16 @@ func TestErrorsTemporary(t *testing.T) {
|
|||
if err == nil {
|
||||
t.Skip("ECONNREFUSED is not set by system")
|
||||
}
|
||||
ne, ok := netError(err).(net.Error)
|
||||
if !ok {
|
||||
|
||||
var ne net.Error
|
||||
if !errors.As(netError(err), &ne) {
|
||||
t.Fatalf("netError must return net.Error")
|
||||
}
|
||||
|
||||
if ne.Timeout() {
|
||||
t.Errorf("%v must not be timeout error", err)
|
||||
}
|
||||
if !ne.Temporary() {
|
||||
if !ne.Temporary() { //nolint:staticcheck
|
||||
t.Errorf("%v must be temporary error", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -65,14 +65,14 @@ func TestErrorNetError(t *testing.T) {
|
|||
for _, c := range cases {
|
||||
c := c
|
||||
t.Run(fmt.Sprintf("%T", c.err), func(t *testing.T) {
|
||||
ne, ok := c.err.(net.Error)
|
||||
if !ok {
|
||||
var ne net.Error
|
||||
if !errors.As(c.err, &ne) {
|
||||
t.Fatalf("%T doesn't implement net.Error", c.err)
|
||||
}
|
||||
if ne.Timeout() != c.timeout {
|
||||
t.Errorf("%T.Timeout() should be %v", c.err, c.timeout)
|
||||
}
|
||||
if ne.Temporary() != c.temporary {
|
||||
if ne.Temporary() != c.temporary { //nolint:staticcheck
|
||||
t.Errorf("%T.Temporary() should be %v", c.err, c.temporary)
|
||||
}
|
||||
if ne.Error() != c.str {
|
||||
|
|
|
@ -51,17 +51,10 @@ func Chat(conn io.ReadWriter) {
|
|||
|
||||
// Check is a helper to throw errors in the examples
|
||||
func Check(err error) {
|
||||
switch e := err.(type) {
|
||||
case nil:
|
||||
case (net.Error):
|
||||
if e.Temporary() {
|
||||
fmt.Printf("Warning: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("net.Error: %v\n", err)
|
||||
panic(err)
|
||||
default:
|
||||
var netError net.Error
|
||||
if errors.As(err, &netError) && netError.Temporary() { //nolint:staticcheck
|
||||
fmt.Printf("Warning: %v\n", err)
|
||||
} else if err != nil {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
panic(err)
|
||||
}
|
||||
|
|
|
@ -221,7 +221,7 @@ func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h
|
|||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
func flight3Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
|
||||
|
|
|
@ -273,7 +273,7 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han
|
|||
|
||||
func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange, sendingPlainText []byte) (*alert.Alert, error) { //nolint:gocognit
|
||||
if state.cipherSuite.IsInitialized() {
|
||||
return nil, nil
|
||||
return nil, nil //nolint
|
||||
}
|
||||
|
||||
clientRandom := state.localRandom.MarshalFixed()
|
||||
|
@ -335,5 +335,5 @@ func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeCon
|
|||
|
||||
cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
|
||||
|
||||
return nil, nil
|
||||
return nil, nil //nolint
|
||||
}
|
||||
|
|
|
@ -31,17 +31,10 @@ func newHandshakeCache() *handshakeCache {
|
|||
return &handshakeCache{}
|
||||
}
|
||||
|
||||
func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ handshake.Type, isClient bool) bool { //nolint
|
||||
func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ handshake.Type, isClient bool) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
for _, i := range h.cache {
|
||||
if i.messageSequence == messageSequence &&
|
||||
i.isClient == isClient {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
h.cache = append(h.cache, &handshakeCacheItem{
|
||||
data: append([]byte{}, data...),
|
||||
epoch: epoch,
|
||||
|
@ -49,7 +42,6 @@ func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ ha
|
|||
typ: typ,
|
||||
isClient: isClient,
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
// returns a list handshakes that match the requested rules
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -280,9 +281,10 @@ func TestHandshaker(t *testing.T) {
|
|||
}
|
||||
|
||||
fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1)
|
||||
switch err := fsm.Run(ctx, ca, handshakePreparing); err {
|
||||
case context.Canceled:
|
||||
case context.DeadlineExceeded:
|
||||
err := fsm.Run(ctx, ca, handshakePreparing)
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled):
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
t.Error("Timeout")
|
||||
default:
|
||||
t.Error(err)
|
||||
|
@ -311,9 +313,10 @@ func TestHandshaker(t *testing.T) {
|
|||
}
|
||||
|
||||
fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0)
|
||||
switch err := fsm.Run(ctx, cb, handshakePreparing); err {
|
||||
case context.Canceled:
|
||||
case context.DeadlineExceeded:
|
||||
err := fsm.Run(ctx, cb, handshakePreparing)
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled):
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
t.Error("Timeout")
|
||||
default:
|
||||
t.Error(err)
|
||||
|
|
|
@ -96,7 +96,12 @@ func (c *AesCcm) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, erro
|
|||
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return ccm.(*ciphersuite.CCM).Encrypt(pkt, raw)
|
||||
cipherSuite, ok := ccm.(*ciphersuite.CCM)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cipherSuite.Encrypt(pkt, raw)
|
||||
}
|
||||
|
||||
// Decrypt decrypts a single TLS RecordLayer
|
||||
|
@ -106,5 +111,10 @@ func (c *AesCcm) Decrypt(raw []byte) ([]byte, error) {
|
|||
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return ccm.(*ciphersuite.CCM).Decrypt(raw)
|
||||
cipherSuite, ok := ccm.(*ciphersuite.CCM)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cipherSuite.Decrypt(raw)
|
||||
}
|
||||
|
|
|
@ -52,26 +52,26 @@ func (i ID) String() string {
|
|||
// Supported Cipher Suites
|
||||
const (
|
||||
// AES-128-CCM
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CCM ID = 0xc0ac //nolint:golint,stylecheck
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ID = 0xc0ae //nolint:golint,stylecheck
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CCM ID = 0xc0ac //nolint:revive,stylecheck
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ID = 0xc0ae //nolint:revive,stylecheck
|
||||
|
||||
// AES-128-GCM-SHA256
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ID = 0xc02b //nolint:golint,stylecheck
|
||||
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ID = 0xc02f //nolint:golint,stylecheck
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ID = 0xc02b //nolint:revive,stylecheck
|
||||
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ID = 0xc02f //nolint:revive,stylecheck
|
||||
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 ID = 0xc02c //nolint:golint,stylecheck
|
||||
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 ID = 0xc030 //nolint:golint,stylecheck
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 ID = 0xc02c //nolint:revive,stylecheck
|
||||
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 ID = 0xc030 //nolint:revive,stylecheck
|
||||
// AES-256-CBC-SHA
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ID = 0xc00a //nolint:golint,stylecheck
|
||||
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ID = 0xc014 //nolint:golint,stylecheck
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ID = 0xc00a //nolint:revive,stylecheck
|
||||
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ID = 0xc014 //nolint:revive,stylecheck
|
||||
|
||||
TLS_PSK_WITH_AES_128_CCM ID = 0xc0a4 //nolint:golint,stylecheck
|
||||
TLS_PSK_WITH_AES_128_CCM_8 ID = 0xc0a8 //nolint:golint,stylecheck
|
||||
TLS_PSK_WITH_AES_256_CCM_8 ID = 0xc0a9 //nolint:golint,stylecheck
|
||||
TLS_PSK_WITH_AES_128_GCM_SHA256 ID = 0x00a8 //nolint:golint,stylecheck
|
||||
TLS_PSK_WITH_AES_128_CBC_SHA256 ID = 0x00ae //nolint:golint,stylecheck
|
||||
TLS_PSK_WITH_AES_128_CCM ID = 0xc0a4 //nolint:revive,stylecheck
|
||||
TLS_PSK_WITH_AES_128_CCM_8 ID = 0xc0a8 //nolint:revive,stylecheck
|
||||
TLS_PSK_WITH_AES_256_CCM_8 ID = 0xc0a9 //nolint:revive,stylecheck
|
||||
TLS_PSK_WITH_AES_128_GCM_SHA256 ID = 0x00a8 //nolint:revive,stylecheck
|
||||
TLS_PSK_WITH_AES_128_CBC_SHA256 ID = 0x00ae //nolint:revive,stylecheck
|
||||
|
||||
TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 ID = 0xC037 //nolint:golint,stylecheck
|
||||
TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 ID = 0xC037 //nolint:revive,stylecheck
|
||||
)
|
||||
|
||||
// AuthenticationType controls what authentication method is using during the handshake
|
||||
|
|
|
@ -91,7 +91,12 @@ func (c *TLSEcdheEcdsaWithAes128GcmSha256) Encrypt(pkt *recordlayer.RecordLayer,
|
|||
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return gcm.(*ciphersuite.GCM).Encrypt(pkt, raw)
|
||||
cipherSuite, ok := gcm.(*ciphersuite.GCM)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cipherSuite.Encrypt(pkt, raw)
|
||||
}
|
||||
|
||||
// Decrypt decrypts a single TLS RecordLayer
|
||||
|
@ -101,5 +106,10 @@ func (c *TLSEcdheEcdsaWithAes128GcmSha256) Decrypt(raw []byte) ([]byte, error) {
|
|||
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return gcm.(*ciphersuite.GCM).Decrypt(raw)
|
||||
cipherSuite, ok := gcm.(*ciphersuite.GCM)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cipherSuite.Decrypt(raw)
|
||||
}
|
||||
|
|
|
@ -97,7 +97,12 @@ func (c *TLSEcdheEcdsaWithAes256CbcSha) Encrypt(pkt *recordlayer.RecordLayer, ra
|
|||
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cbc.(*ciphersuite.CBC).Encrypt(pkt, raw)
|
||||
cipherSuite, ok := cbc.(*ciphersuite.CBC)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cipherSuite.Encrypt(pkt, raw)
|
||||
}
|
||||
|
||||
// Decrypt decrypts a single TLS RecordLayer
|
||||
|
@ -107,5 +112,10 @@ func (c *TLSEcdheEcdsaWithAes256CbcSha) Decrypt(raw []byte) ([]byte, error) {
|
|||
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cbc.(*ciphersuite.CBC).Decrypt(raw)
|
||||
cipherSuite, ok := cbc.(*ciphersuite.CBC)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cipherSuite.Decrypt(raw)
|
||||
}
|
||||
|
|
|
@ -101,15 +101,25 @@ func (c *TLSEcdhePskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, r
|
|||
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cbc.(*ciphersuite.CBC).Encrypt(pkt, raw)
|
||||
cipherSuite, ok := cbc.(*ciphersuite.CBC)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cipherSuite.Encrypt(pkt, raw)
|
||||
}
|
||||
|
||||
// Decrypt decrypts a single TLS RecordLayer
|
||||
func (c *TLSEcdhePskWithAes128CbcSha256) Decrypt(raw []byte) ([]byte, error) {
|
||||
cbc := c.cbc.Load()
|
||||
if cbc == nil { // !c.isInitialized()
|
||||
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
|
||||
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cbc.(*ciphersuite.CBC).Decrypt(raw)
|
||||
cipherSuite, ok := cbc.(*ciphersuite.CBC)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cipherSuite.Decrypt(raw)
|
||||
}
|
||||
|
|
|
@ -93,10 +93,15 @@ func (c *TLSPskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRando
|
|||
func (c *TLSPskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
|
||||
cbc := c.cbc.Load()
|
||||
if cbc == nil { // !c.isInitialized()
|
||||
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
|
||||
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cbc.(*ciphersuite.CBC).Encrypt(pkt, raw)
|
||||
cipherSuite, ok := cbc.(*ciphersuite.CBC)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cipherSuite.Encrypt(pkt, raw)
|
||||
}
|
||||
|
||||
// Decrypt decrypts a single TLS RecordLayer
|
||||
|
@ -106,5 +111,10 @@ func (c *TLSPskWithAes128CbcSha256) Decrypt(raw []byte) ([]byte, error) {
|
|||
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cbc.(*ciphersuite.CBC).Decrypt(raw)
|
||||
cipherSuite, ok := cbc.(*ciphersuite.CBC)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
|
||||
}
|
||||
|
||||
return cipherSuite.Decrypt(raw)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ package dpipe
|
|||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
@ -14,11 +15,23 @@ import (
|
|||
"golang.org/x/net/nettest"
|
||||
)
|
||||
|
||||
var errFailedToCast = fmt.Errorf("failed to cast net.Conn to conn")
|
||||
|
||||
func TestNetTest(t *testing.T) {
|
||||
nettest.TestConn(t, func() (net.Conn, net.Conn, func(), error) {
|
||||
ca, cb := Pipe()
|
||||
return &closePropagator{ca.(*conn), cb.(*conn)},
|
||||
&closePropagator{cb.(*conn), ca.(*conn)},
|
||||
caConn, ok := ca.(*conn)
|
||||
if !ok {
|
||||
return nil, nil, nil, errFailedToCast
|
||||
}
|
||||
|
||||
cbConn, ok := cb.(*conn)
|
||||
if !ok {
|
||||
return nil, nil, nil, errFailedToCast
|
||||
}
|
||||
|
||||
return &closePropagator{caConn, cbConn},
|
||||
&closePropagator{cbConn, caConn},
|
||||
func() {
|
||||
_ = ca.Close()
|
||||
_ = cb.Close()
|
||||
|
|
|
@ -245,7 +245,6 @@ func TestRFC3610Vectors(t *testing.T) {
|
|||
for idx, c := range cases {
|
||||
c := c
|
||||
t.Run(fmt.Sprintf("packet vector #%d", idx+1), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
blk, err := aes.NewCipher(c.AESKey)
|
||||
if err != nil {
|
||||
t.Fatalf("could not initialize AES block cipher from key: %v", err)
|
||||
|
@ -365,7 +364,12 @@ func TestSealError(t *testing.T) {
|
|||
c := c
|
||||
t.Run(name, func(t *testing.T) {
|
||||
defer func() {
|
||||
if err := recover(); !errors.Is(err.(error), c.err) {
|
||||
err, ok := recover().(error)
|
||||
if !ok {
|
||||
t.Errorf("expected panic '%v', got '%v'", c.err, err)
|
||||
}
|
||||
|
||||
if !errors.Is(err, c.err) {
|
||||
t.Errorf("expected panic '%v', got '%v'", c.err, err)
|
||||
}
|
||||
}()
|
||||
|
|
|
@ -39,11 +39,21 @@ func NewCBC(localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMa
|
|||
return nil, err
|
||||
}
|
||||
|
||||
writeCBC, ok := cipher.NewCBCEncrypter(writeBlock, localWriteIV).(cbcMode)
|
||||
if !ok {
|
||||
return nil, errFailedToCast
|
||||
}
|
||||
|
||||
readCBC, ok := cipher.NewCBCDecrypter(readBlock, remoteWriteIV).(cbcMode)
|
||||
if !ok {
|
||||
return nil, errFailedToCast
|
||||
}
|
||||
|
||||
return &CBC{
|
||||
writeCBC: cipher.NewCBCEncrypter(writeBlock, localWriteIV).(cbcMode),
|
||||
writeCBC: writeCBC,
|
||||
writeMac: localMac,
|
||||
|
||||
readCBC: cipher.NewCBCDecrypter(readBlock, remoteWriteIV).(cbcMode),
|
||||
readCBC: readCBC,
|
||||
readMac: remoteMac,
|
||||
h: h,
|
||||
}, nil
|
||||
|
|
|
@ -13,6 +13,7 @@ var (
|
|||
errNotEnoughRoomForNonce = &protocol.InternalError{Err: errors.New("buffer not long enough to contain nonce")} //nolint:goerr113
|
||||
errDecryptPacket = &protocol.TemporaryError{Err: errors.New("failed to decrypt packet")} //nolint:goerr113
|
||||
errInvalidMAC = &protocol.TemporaryError{Err: errors.New("invalid mac")} //nolint:goerr113
|
||||
errFailedToCast = &protocol.FatalError{Err: errors.New("failed to cast")} //nolint:goerr113
|
||||
)
|
||||
|
||||
func generateAEADAdditionalData(h *recordlayer.Header, payloadLen int) []byte {
|
||||
|
|
|
@ -68,7 +68,7 @@ func Curves() map[Curve]bool {
|
|||
|
||||
// GenerateKeypair generates a keypair for the given Curve
|
||||
func GenerateKeypair(c Curve) (*Keypair, error) {
|
||||
switch c { //nolint:golint
|
||||
switch c { //nolint:revive
|
||||
case X25519:
|
||||
tmp := make([]byte, 32)
|
||||
if _, err := rand.Read(tmp); err != nil {
|
||||
|
|
|
@ -84,7 +84,8 @@ func (e *TimeoutError) Error() string { return fmt.Sprintf("dtls timeout: %v", e
|
|||
|
||||
// Timeout implements net.Error.Timeout()
|
||||
func (e *HandshakeError) Timeout() bool {
|
||||
if netErr, ok := e.Err.(net.Error); ok {
|
||||
var netErr net.Error
|
||||
if errors.As(e.Err, &netErr) {
|
||||
return netErr.Timeout()
|
||||
}
|
||||
return false
|
||||
|
@ -92,8 +93,9 @@ func (e *HandshakeError) Timeout() bool {
|
|||
|
||||
// Temporary implements net.Error.Temporary()
|
||||
func (e *HandshakeError) Temporary() bool {
|
||||
if netErr, ok := e.Err.(net.Error); ok {
|
||||
return netErr.Temporary()
|
||||
var netErr net.Error
|
||||
if errors.As(e.Err, &netErr) {
|
||||
return netErr.Temporary() //nolint
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -7,8 +7,8 @@ import "github.com/pion/dtls/v2/pkg/protocol/extension"
|
|||
type SRTPProtectionProfile = extension.SRTPProtectionProfile
|
||||
|
||||
const (
|
||||
SRTP_AES128_CM_HMAC_SHA1_80 SRTPProtectionProfile = extension.SRTP_AES128_CM_HMAC_SHA1_80 // nolint
|
||||
SRTP_AES128_CM_HMAC_SHA1_32 SRTPProtectionProfile = extension.SRTP_AES128_CM_HMAC_SHA1_32 // nolint
|
||||
SRTP_AEAD_AES_128_GCM SRTPProtectionProfile = extension.SRTP_AEAD_AES_128_GCM // nolint
|
||||
SRTP_AEAD_AES_256_GCM SRTPProtectionProfile = extension.SRTP_AEAD_AES_256_GCM // nolint
|
||||
SRTP_AES128_CM_HMAC_SHA1_80 SRTPProtectionProfile = extension.SRTP_AES128_CM_HMAC_SHA1_80 // nolint:revive,stylecheck
|
||||
SRTP_AES128_CM_HMAC_SHA1_32 SRTPProtectionProfile = extension.SRTP_AES128_CM_HMAC_SHA1_32 // nolint:revive,stylecheck
|
||||
SRTP_AEAD_AES_128_GCM SRTPProtectionProfile = extension.SRTP_AEAD_AES_128_GCM // nolint:revive,stylecheck
|
||||
SRTP_AEAD_AES_256_GCM SRTPProtectionProfile = extension.SRTP_AEAD_AES_256_GCM // nolint:revive,stylecheck
|
||||
)
|
||||
|
|
24
state.go
24
state.go
|
@ -75,10 +75,10 @@ func (s *State) serialize() *serializedState {
|
|||
localRnd := s.localRandom.MarshalFixed()
|
||||
remoteRnd := s.remoteRandom.MarshalFixed()
|
||||
|
||||
epoch := s.localEpoch.Load().(uint16)
|
||||
epoch := s.getLocalEpoch()
|
||||
return &serializedState{
|
||||
LocalEpoch: epoch,
|
||||
RemoteEpoch: s.remoteEpoch.Load().(uint16),
|
||||
LocalEpoch: s.getLocalEpoch(),
|
||||
RemoteEpoch: s.getRemoteEpoch(),
|
||||
CipherSuiteID: uint16(s.cipherSuite.ID()),
|
||||
MasterSecret: s.masterSecret,
|
||||
SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]),
|
||||
|
@ -180,7 +180,7 @@ func (s *State) UnmarshalBinary(data []byte) error {
|
|||
// This allows protocols to use DTLS for key establishment, but
|
||||
// then use some of the keying material for their own purposes
|
||||
func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
|
||||
if s.localEpoch.Load().(uint16) == 0 {
|
||||
if s.getLocalEpoch() == 0 {
|
||||
return nil, errHandshakeInProgress
|
||||
} else if len(context) != 0 {
|
||||
return nil, errContextUnsupported
|
||||
|
@ -199,3 +199,19 @@ func (s *State) ExportKeyingMaterial(label string, context []byte, length int) (
|
|||
}
|
||||
return prf.PHash(s.masterSecret, seed, length, s.cipherSuite.HashFunc())
|
||||
}
|
||||
|
||||
func (s *State) getRemoteEpoch() uint16 {
|
||||
remoteEpoch, ok := s.remoteEpoch.Load().(uint16)
|
||||
if ok {
|
||||
return remoteEpoch
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *State) getLocalEpoch() uint16 {
|
||||
localEpoch, ok := s.localEpoch.Load().(uint16)
|
||||
if ok {
|
||||
return localEpoch
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
|
2
util.go
2
util.go
|
@ -11,7 +11,7 @@ func findMatchingSRTPProfile(a, b []SRTPProtectionProfile) (SRTPProtectionProfil
|
|||
return 0, false
|
||||
}
|
||||
|
||||
func findMatchingCipherSuite(a, b []CipherSuite) (CipherSuite, bool) { //nolint
|
||||
func findMatchingCipherSuite(a, b []CipherSuite) (CipherSuite, bool) {
|
||||
for _, aSuite := range a {
|
||||
for _, bSuite := range b {
|
||||
if aSuite.ID() == bSuite.ID() {
|
||||
|
|
Загрузка…
Ссылка в новой задаче