ssh: eliminate some goroutine leaks in tests and examples

This should fix the "Log in goroutine" panic seen in
https://build.golang.org/log/e42bf69fc002113dbccfe602a6c67fd52e8f31df,
as well as a few other related leaks. It also helps to verify that
none of the functions under test deadlock unexpectedly.

See https://go.dev/wiki/CodeReviewComments#goroutine-lifetimes.

Updates golang/go#58901.

Change-Id: Ica943444db381ae1accb80b101ea646e28ebf4f9
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/541095
Auto-Submit: Bryan Mills <bcmills@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Heschi Kreinick <heschi@google.com>
This commit is contained in:
Bryan C. Mills 2023-11-09 09:23:46 -05:00 коммит произвёл Gopher Robot
Родитель eb61739cd9
Коммит ff15cd57d1
3 изменённых файлов: 124 добавлений и 56 удалений

Просмотреть файл

@ -16,6 +16,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal" "golang.org/x/crypto/ssh/terminal"
@ -98,8 +99,15 @@ func ExampleNewServerConn() {
} }
log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"]) log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"])
var wg sync.WaitGroup
defer wg.Wait()
// The incoming Request channel must be serviced. // The incoming Request channel must be serviced.
go ssh.DiscardRequests(reqs) wg.Add(1)
go func() {
ssh.DiscardRequests(reqs)
wg.Done()
}()
// Service the incoming Channel channel. // Service the incoming Channel channel.
for newChannel := range chans { for newChannel := range chans {
@ -119,16 +127,22 @@ func ExampleNewServerConn() {
// Sessions have out-of-band requests such as "shell", // Sessions have out-of-band requests such as "shell",
// "pty-req" and "env". Here we handle only the // "pty-req" and "env". Here we handle only the
// "shell" request. // "shell" request.
wg.Add(1)
go func(in <-chan *ssh.Request) { go func(in <-chan *ssh.Request) {
for req := range in { for req := range in {
req.Reply(req.Type == "shell", nil) req.Reply(req.Type == "shell", nil)
} }
wg.Done()
}(requests) }(requests)
term := terminal.NewTerminal(channel, "> ") term := terminal.NewTerminal(channel, "> ")
wg.Add(1)
go func() { go func() {
defer channel.Close() defer func() {
channel.Close()
wg.Done()
}()
for { for {
line, err := term.ReadLine() line, err := term.ReadLine()
if err != nil { if err != nil {

Просмотреть файл

@ -10,7 +10,6 @@ import (
"io" "io"
"sync" "sync"
"testing" "testing"
"time"
) )
func muxPair() (*mux, *mux) { func muxPair() (*mux, *mux) {
@ -112,7 +111,11 @@ func TestMuxReadWrite(t *testing.T) {
magic := "hello world" magic := "hello world"
magicExt := "hello stderr" magicExt := "hello stderr"
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() { go func() {
defer wg.Done()
_, err := s.Write([]byte(magic)) _, err := s.Write([]byte(magic))
if err != nil { if err != nil {
t.Errorf("Write: %v", err) t.Errorf("Write: %v", err)
@ -152,13 +155,15 @@ func TestMuxChannelOverflow(t *testing.T) {
defer writer.Close() defer writer.Close()
defer mux.Close() defer mux.Close()
wDone := make(chan int, 1) var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() { go func() {
defer wg.Done()
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err) t.Errorf("could not fill window: %v", err)
} }
writer.Write(make([]byte, 1)) writer.Write(make([]byte, 1))
wDone <- 1
}() }()
writer.remoteWin.waitWriterBlocked() writer.remoteWin.waitWriterBlocked()
@ -175,7 +180,6 @@ func TestMuxChannelOverflow(t *testing.T) {
if _, err := reader.SendRequest("hello", true, nil); err == nil { if _, err := reader.SendRequest("hello", true, nil); err == nil {
t.Errorf("SendRequest succeeded.") t.Errorf("SendRequest succeeded.")
} }
<-wDone
} }
func TestMuxChannelCloseWriteUnblock(t *testing.T) { func TestMuxChannelCloseWriteUnblock(t *testing.T) {
@ -184,20 +188,21 @@ func TestMuxChannelCloseWriteUnblock(t *testing.T) {
defer writer.Close() defer writer.Close()
defer mux.Close() defer mux.Close()
wDone := make(chan int, 1) var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() { go func() {
defer wg.Done()
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err) t.Errorf("could not fill window: %v", err)
} }
if _, err := writer.Write(make([]byte, 1)); err != io.EOF { if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
t.Errorf("got %v, want EOF for unblock write", err) t.Errorf("got %v, want EOF for unblock write", err)
} }
wDone <- 1
}() }()
writer.remoteWin.waitWriterBlocked() writer.remoteWin.waitWriterBlocked()
reader.Close() reader.Close()
<-wDone
} }
func TestMuxConnectionCloseWriteUnblock(t *testing.T) { func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
@ -206,20 +211,21 @@ func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
defer writer.Close() defer writer.Close()
defer mux.Close() defer mux.Close()
wDone := make(chan int, 1) var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() { go func() {
defer wg.Done()
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err) t.Errorf("could not fill window: %v", err)
} }
if _, err := writer.Write(make([]byte, 1)); err != io.EOF { if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
t.Errorf("got %v, want EOF for unblock write", err) t.Errorf("got %v, want EOF for unblock write", err)
} }
wDone <- 1
}() }()
writer.remoteWin.waitWriterBlocked() writer.remoteWin.waitWriterBlocked()
mux.Close() mux.Close()
<-wDone
} }
func TestMuxReject(t *testing.T) { func TestMuxReject(t *testing.T) {
@ -227,7 +233,12 @@ func TestMuxReject(t *testing.T) {
defer server.Close() defer server.Close()
defer client.Close() defer client.Close()
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() { go func() {
defer wg.Done()
ch, ok := <-server.incomingChannels ch, ok := <-server.incomingChannels
if !ok { if !ok {
t.Error("cannot accept channel") t.Error("cannot accept channel")
@ -267,6 +278,7 @@ func TestMuxChannelRequest(t *testing.T) {
var received int var received int
var wg sync.WaitGroup var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1) wg.Add(1)
go func() { go func() {
for r := range server.incomingRequests { for r := range server.incomingRequests {
@ -295,7 +307,6 @@ func TestMuxChannelRequest(t *testing.T) {
} }
if ok { if ok {
t.Errorf("SendRequest(no): %v", ok) t.Errorf("SendRequest(no): %v", ok)
} }
client.Close() client.Close()
@ -389,13 +400,8 @@ func TestMuxUnknownChannelRequests(t *testing.T) {
// Wait for the server to send the keepalive message and receive back a // Wait for the server to send the keepalive message and receive back a
// response. // response.
select { if err := <-kDone; err != nil {
case err := <-kDone: t.Fatal(err)
if err != nil {
t.Fatal(err)
}
case <-time.After(10 * time.Second):
t.Fatalf("server never received ack")
} }
// Confirm client hasn't closed. // Confirm client hasn't closed.
@ -403,13 +409,9 @@ func TestMuxUnknownChannelRequests(t *testing.T) {
t.Fatalf("failed to send keepalive: %v", err) t.Fatalf("failed to send keepalive: %v", err)
} }
select { // Wait for the server to shut down.
case err := <-kDone: if err := <-kDone; err != nil {
if err != nil { t.Fatal(err)
t.Fatal(err)
}
case <-time.After(10 * time.Second):
t.Fatalf("server never shut down")
} }
} }
@ -525,11 +527,7 @@ func TestMuxClosedChannel(t *testing.T) {
defer ch.Close() defer ch.Close()
// Wait for the server to close the channel and send the keepalive. // Wait for the server to close the channel and send the keepalive.
select { <-kDone
case <-kDone:
case <-time.After(10 * time.Second):
t.Fatalf("server never received ack")
}
// Make sure the channel closed. // Make sure the channel closed.
if _, ok := <-ch.incomingRequests; ok { if _, ok := <-ch.incomingRequests; ok {
@ -541,22 +539,29 @@ func TestMuxClosedChannel(t *testing.T) {
t.Fatalf("failed to send keepalive: %v", err) t.Fatalf("failed to send keepalive: %v", err)
} }
select { // Wait for the server to shut down.
case <-kDone: <-kDone
case <-time.After(10 * time.Second):
t.Fatalf("server never shut down")
}
} }
func TestMuxGlobalRequest(t *testing.T) { func TestMuxGlobalRequest(t *testing.T) {
var sawPeek bool
var wg sync.WaitGroup
defer func() {
wg.Wait()
if !sawPeek {
t.Errorf("never saw 'peek' request")
}
}()
clientMux, serverMux := muxPair() clientMux, serverMux := muxPair()
defer serverMux.Close() defer serverMux.Close()
defer clientMux.Close() defer clientMux.Close()
var seen bool wg.Add(1)
go func() { go func() {
defer wg.Done()
for r := range serverMux.incomingRequests { for r := range serverMux.incomingRequests {
seen = seen || r.Type == "peek" sawPeek = sawPeek || r.Type == "peek"
if r.WantReply { if r.WantReply {
err := r.Reply(r.Type == "yes", err := r.Reply(r.Type == "yes",
append([]byte(r.Type), r.Payload...)) append([]byte(r.Type), r.Payload...))
@ -586,10 +591,6 @@ func TestMuxGlobalRequest(t *testing.T) {
t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
ok, data, err) ok, data, err)
} }
if !seen {
t.Errorf("never saw 'peek' request")
}
} }
func TestMuxGlobalRequestUnblock(t *testing.T) { func TestMuxGlobalRequestUnblock(t *testing.T) {
@ -739,7 +740,13 @@ func TestMuxMaxPacketSize(t *testing.T) {
t.Errorf("could not send packet") t.Errorf("could not send packet")
} }
go a.SendRequest("hello", false, nil) var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
a.SendRequest("hello", false, nil)
wg.Done()
}()
_, ok := <-b.incomingRequests _, ok := <-b.incomingRequests
if ok { if ok {

Просмотреть файл

@ -13,6 +13,7 @@ import (
"io" "io"
"math/rand" "math/rand"
"net" "net"
"sync"
"testing" "testing"
"golang.org/x/crypto/ssh/terminal" "golang.org/x/crypto/ssh/terminal"
@ -27,8 +28,14 @@ func dial(handler serverType, t *testing.T) *Client {
t.Fatalf("netPipe: %v", err) t.Fatalf("netPipe: %v", err)
} }
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() { go func() {
defer c1.Close() defer func() {
c1.Close()
wg.Done()
}()
conf := ServerConfig{ conf := ServerConfig{
NoClientAuth: true, NoClientAuth: true,
} }
@ -39,7 +46,11 @@ func dial(handler serverType, t *testing.T) *Client {
t.Errorf("Unable to handshake: %v", err) t.Errorf("Unable to handshake: %v", err)
return return
} }
go DiscardRequests(reqs) wg.Add(1)
go func() {
DiscardRequests(reqs)
wg.Done()
}()
for newCh := range chans { for newCh := range chans {
if newCh.ChannelType() != "session" { if newCh.ChannelType() != "session" {
@ -52,8 +63,10 @@ func dial(handler serverType, t *testing.T) *Client {
t.Errorf("Accept: %v", err) t.Errorf("Accept: %v", err)
continue continue
} }
wg.Add(1)
go func() { go func() {
handler(ch, inReqs, t) handler(ch, inReqs, t)
wg.Done()
}() }()
} }
if err := conn.Wait(); err != io.EOF { if err := conn.Wait(); err != io.EOF {
@ -338,8 +351,13 @@ func TestServerWindow(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
defer session.Close() defer session.Close()
result := make(chan []byte)
serverStdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("StdinPipe failed: %v", err)
}
result := make(chan []byte)
go func() { go func() {
defer close(result) defer close(result)
echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
@ -355,10 +373,6 @@ func TestServerWindow(t *testing.T) {
result <- echoedBuf.Bytes() result <- echoedBuf.Bytes()
}() }()
serverStdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("StdinPipe failed: %v", err)
}
written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes) written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
if err != nil { if err != nil {
t.Errorf("failed to copy origBuf to serverStdin: %v", err) t.Errorf("failed to copy origBuf to serverStdin: %v", err)
@ -648,29 +662,44 @@ func TestSessionID(t *testing.T) {
User: "user", User: "user",
} }
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
srvErrCh := make(chan error, 1) srvErrCh := make(chan error, 1)
wg.Add(1)
go func() { go func() {
defer wg.Done()
conn, chans, reqs, err := NewServerConn(c1, serverConf) conn, chans, reqs, err := NewServerConn(c1, serverConf)
srvErrCh <- err srvErrCh <- err
if err != nil { if err != nil {
return return
} }
serverID <- conn.SessionID() serverID <- conn.SessionID()
go DiscardRequests(reqs) wg.Add(1)
go func() {
DiscardRequests(reqs)
wg.Done()
}()
for ch := range chans { for ch := range chans {
ch.Reject(Prohibited, "") ch.Reject(Prohibited, "")
} }
}() }()
cliErrCh := make(chan error, 1) cliErrCh := make(chan error, 1)
wg.Add(1)
go func() { go func() {
defer wg.Done()
conn, chans, reqs, err := NewClientConn(c2, "", clientConf) conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
cliErrCh <- err cliErrCh <- err
if err != nil { if err != nil {
return return
} }
clientID <- conn.SessionID() clientID <- conn.SessionID()
go DiscardRequests(reqs) wg.Add(1)
go func() {
DiscardRequests(reqs)
wg.Done()
}()
for ch := range chans { for ch := range chans {
ch.Reject(Prohibited, "") ch.Reject(Prohibited, "")
} }
@ -738,6 +767,8 @@ func TestHostKeyAlgorithms(t *testing.T) {
serverConf.AddHostKey(testSigners["rsa"]) serverConf.AddHostKey(testSigners["rsa"])
serverConf.AddHostKey(testSigners["ecdsa"]) serverConf.AddHostKey(testSigners["ecdsa"])
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
connect := func(clientConf *ClientConfig, want string) { connect := func(clientConf *ClientConfig, want string) {
var alg string var alg string
clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error { clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error {
@ -751,7 +782,11 @@ func TestHostKeyAlgorithms(t *testing.T) {
defer c1.Close() defer c1.Close()
defer c2.Close() defer c2.Close()
go NewServerConn(c1, serverConf) wg.Add(1)
go func() {
NewServerConn(c1, serverConf)
wg.Done()
}()
_, _, _, err = NewClientConn(c2, "", clientConf) _, _, _, err = NewClientConn(c2, "", clientConf)
if err != nil { if err != nil {
t.Fatalf("NewClientConn: %v", err) t.Fatalf("NewClientConn: %v", err)
@ -785,7 +820,11 @@ func TestHostKeyAlgorithms(t *testing.T) {
defer c1.Close() defer c1.Close()
defer c2.Close() defer c2.Close()
go NewServerConn(c1, serverConf) wg.Add(1)
go func() {
NewServerConn(c1, serverConf)
wg.Done()
}()
clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"} clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"}
_, _, _, err = NewClientConn(c2, "", clientConf) _, _, _, err = NewClientConn(c2, "", clientConf)
if err == nil { if err == nil {
@ -818,14 +857,22 @@ func TestServerClientAuthCallback(t *testing.T) {
User: someUsername, User: someUsername,
} }
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() { go func() {
defer wg.Done()
_, chans, reqs, err := NewServerConn(c1, serverConf) _, chans, reqs, err := NewServerConn(c1, serverConf)
if err != nil { if err != nil {
t.Errorf("server handshake: %v", err) t.Errorf("server handshake: %v", err)
userCh <- "error" userCh <- "error"
return return
} }
go DiscardRequests(reqs) wg.Add(1)
go func() {
DiscardRequests(reqs)
wg.Done()
}()
for ch := range chans { for ch := range chans {
ch.Reject(Prohibited, "") ch.Reject(Prohibited, "")
} }