ssh: eliminate some goroutine leaks in tests and examples

This should fix the "Log in goroutine" panic seen in
https://build.golang.org/log/e42bf69fc002113dbccfe602a6c67fd52e8f31df,
as well as a few other related leaks. It also helps to verify that
none of the functions under test deadlock unexpectedly.

See https://go.dev/wiki/CodeReviewComments#goroutine-lifetimes.

Updates golang/go#58901.

Change-Id: Ica943444db381ae1accb80b101ea646e28ebf4f9
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/541095
Auto-Submit: Bryan Mills <bcmills@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Heschi Kreinick <heschi@google.com>
This commit is contained in:
Bryan C. Mills 2023-11-09 09:23:46 -05:00 коммит произвёл Gopher Robot
Родитель eb61739cd9
Коммит ff15cd57d1
3 изменённых файлов: 124 добавлений и 56 удалений

Просмотреть файл

@ -16,6 +16,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
@ -98,8 +99,15 @@ func ExampleNewServerConn() {
}
log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"])
var wg sync.WaitGroup
defer wg.Wait()
// The incoming Request channel must be serviced.
go ssh.DiscardRequests(reqs)
wg.Add(1)
go func() {
ssh.DiscardRequests(reqs)
wg.Done()
}()
// Service the incoming Channel channel.
for newChannel := range chans {
@ -119,16 +127,22 @@ func ExampleNewServerConn() {
// Sessions have out-of-band requests such as "shell",
// "pty-req" and "env". Here we handle only the
// "shell" request.
wg.Add(1)
go func(in <-chan *ssh.Request) {
for req := range in {
req.Reply(req.Type == "shell", nil)
}
wg.Done()
}(requests)
term := terminal.NewTerminal(channel, "> ")
wg.Add(1)
go func() {
defer channel.Close()
defer func() {
channel.Close()
wg.Done()
}()
for {
line, err := term.ReadLine()
if err != nil {

Просмотреть файл

@ -10,7 +10,6 @@ import (
"io"
"sync"
"testing"
"time"
)
func muxPair() (*mux, *mux) {
@ -112,7 +111,11 @@ func TestMuxReadWrite(t *testing.T) {
magic := "hello world"
magicExt := "hello stderr"
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
defer wg.Done()
_, err := s.Write([]byte(magic))
if err != nil {
t.Errorf("Write: %v", err)
@ -152,13 +155,15 @@ func TestMuxChannelOverflow(t *testing.T) {
defer writer.Close()
defer mux.Close()
wDone := make(chan int, 1)
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
defer wg.Done()
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
writer.Write(make([]byte, 1))
wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
@ -175,7 +180,6 @@ func TestMuxChannelOverflow(t *testing.T) {
if _, err := reader.SendRequest("hello", true, nil); err == nil {
t.Errorf("SendRequest succeeded.")
}
<-wDone
}
func TestMuxChannelCloseWriteUnblock(t *testing.T) {
@ -184,20 +188,21 @@ func TestMuxChannelCloseWriteUnblock(t *testing.T) {
defer writer.Close()
defer mux.Close()
wDone := make(chan int, 1)
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
defer wg.Done()
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
t.Errorf("got %v, want EOF for unblock write", err)
}
wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
reader.Close()
<-wDone
}
func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
@ -206,20 +211,21 @@ func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
defer writer.Close()
defer mux.Close()
wDone := make(chan int, 1)
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
defer wg.Done()
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
t.Errorf("got %v, want EOF for unblock write", err)
}
wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
mux.Close()
<-wDone
}
func TestMuxReject(t *testing.T) {
@ -227,7 +233,12 @@ func TestMuxReject(t *testing.T) {
defer server.Close()
defer client.Close()
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
defer wg.Done()
ch, ok := <-server.incomingChannels
if !ok {
t.Error("cannot accept channel")
@ -267,6 +278,7 @@ func TestMuxChannelRequest(t *testing.T) {
var received int
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
for r := range server.incomingRequests {
@ -295,7 +307,6 @@ func TestMuxChannelRequest(t *testing.T) {
}
if ok {
t.Errorf("SendRequest(no): %v", ok)
}
client.Close()
@ -389,13 +400,8 @@ func TestMuxUnknownChannelRequests(t *testing.T) {
// Wait for the server to send the keepalive message and receive back a
// response.
select {
case err := <-kDone:
if err != nil {
t.Fatal(err)
}
case <-time.After(10 * time.Second):
t.Fatalf("server never received ack")
if err := <-kDone; err != nil {
t.Fatal(err)
}
// Confirm client hasn't closed.
@ -403,13 +409,9 @@ func TestMuxUnknownChannelRequests(t *testing.T) {
t.Fatalf("failed to send keepalive: %v", err)
}
select {
case err := <-kDone:
if err != nil {
t.Fatal(err)
}
case <-time.After(10 * time.Second):
t.Fatalf("server never shut down")
// Wait for the server to shut down.
if err := <-kDone; err != nil {
t.Fatal(err)
}
}
@ -525,11 +527,7 @@ func TestMuxClosedChannel(t *testing.T) {
defer ch.Close()
// Wait for the server to close the channel and send the keepalive.
select {
case <-kDone:
case <-time.After(10 * time.Second):
t.Fatalf("server never received ack")
}
<-kDone
// Make sure the channel closed.
if _, ok := <-ch.incomingRequests; ok {
@ -541,22 +539,29 @@ func TestMuxClosedChannel(t *testing.T) {
t.Fatalf("failed to send keepalive: %v", err)
}
select {
case <-kDone:
case <-time.After(10 * time.Second):
t.Fatalf("server never shut down")
}
// Wait for the server to shut down.
<-kDone
}
func TestMuxGlobalRequest(t *testing.T) {
var sawPeek bool
var wg sync.WaitGroup
defer func() {
wg.Wait()
if !sawPeek {
t.Errorf("never saw 'peek' request")
}
}()
clientMux, serverMux := muxPair()
defer serverMux.Close()
defer clientMux.Close()
var seen bool
wg.Add(1)
go func() {
defer wg.Done()
for r := range serverMux.incomingRequests {
seen = seen || r.Type == "peek"
sawPeek = sawPeek || r.Type == "peek"
if r.WantReply {
err := r.Reply(r.Type == "yes",
append([]byte(r.Type), r.Payload...))
@ -586,10 +591,6 @@ func TestMuxGlobalRequest(t *testing.T) {
t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
ok, data, err)
}
if !seen {
t.Errorf("never saw 'peek' request")
}
}
func TestMuxGlobalRequestUnblock(t *testing.T) {
@ -739,7 +740,13 @@ func TestMuxMaxPacketSize(t *testing.T) {
t.Errorf("could not send packet")
}
go a.SendRequest("hello", false, nil)
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
a.SendRequest("hello", false, nil)
wg.Done()
}()
_, ok := <-b.incomingRequests
if ok {

Просмотреть файл

@ -13,6 +13,7 @@ import (
"io"
"math/rand"
"net"
"sync"
"testing"
"golang.org/x/crypto/ssh/terminal"
@ -27,8 +28,14 @@ func dial(handler serverType, t *testing.T) *Client {
t.Fatalf("netPipe: %v", err)
}
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
defer c1.Close()
defer func() {
c1.Close()
wg.Done()
}()
conf := ServerConfig{
NoClientAuth: true,
}
@ -39,7 +46,11 @@ func dial(handler serverType, t *testing.T) *Client {
t.Errorf("Unable to handshake: %v", err)
return
}
go DiscardRequests(reqs)
wg.Add(1)
go func() {
DiscardRequests(reqs)
wg.Done()
}()
for newCh := range chans {
if newCh.ChannelType() != "session" {
@ -52,8 +63,10 @@ func dial(handler serverType, t *testing.T) *Client {
t.Errorf("Accept: %v", err)
continue
}
wg.Add(1)
go func() {
handler(ch, inReqs, t)
wg.Done()
}()
}
if err := conn.Wait(); err != io.EOF {
@ -338,8 +351,13 @@ func TestServerWindow(t *testing.T) {
t.Fatal(err)
}
defer session.Close()
result := make(chan []byte)
serverStdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("StdinPipe failed: %v", err)
}
result := make(chan []byte)
go func() {
defer close(result)
echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
@ -355,10 +373,6 @@ func TestServerWindow(t *testing.T) {
result <- echoedBuf.Bytes()
}()
serverStdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("StdinPipe failed: %v", err)
}
written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
if err != nil {
t.Errorf("failed to copy origBuf to serverStdin: %v", err)
@ -648,29 +662,44 @@ func TestSessionID(t *testing.T) {
User: "user",
}
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
srvErrCh := make(chan error, 1)
wg.Add(1)
go func() {
defer wg.Done()
conn, chans, reqs, err := NewServerConn(c1, serverConf)
srvErrCh <- err
if err != nil {
return
}
serverID <- conn.SessionID()
go DiscardRequests(reqs)
wg.Add(1)
go func() {
DiscardRequests(reqs)
wg.Done()
}()
for ch := range chans {
ch.Reject(Prohibited, "")
}
}()
cliErrCh := make(chan error, 1)
wg.Add(1)
go func() {
defer wg.Done()
conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
cliErrCh <- err
if err != nil {
return
}
clientID <- conn.SessionID()
go DiscardRequests(reqs)
wg.Add(1)
go func() {
DiscardRequests(reqs)
wg.Done()
}()
for ch := range chans {
ch.Reject(Prohibited, "")
}
@ -738,6 +767,8 @@ func TestHostKeyAlgorithms(t *testing.T) {
serverConf.AddHostKey(testSigners["rsa"])
serverConf.AddHostKey(testSigners["ecdsa"])
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
connect := func(clientConf *ClientConfig, want string) {
var alg string
clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error {
@ -751,7 +782,11 @@ func TestHostKeyAlgorithms(t *testing.T) {
defer c1.Close()
defer c2.Close()
go NewServerConn(c1, serverConf)
wg.Add(1)
go func() {
NewServerConn(c1, serverConf)
wg.Done()
}()
_, _, _, err = NewClientConn(c2, "", clientConf)
if err != nil {
t.Fatalf("NewClientConn: %v", err)
@ -785,7 +820,11 @@ func TestHostKeyAlgorithms(t *testing.T) {
defer c1.Close()
defer c2.Close()
go NewServerConn(c1, serverConf)
wg.Add(1)
go func() {
NewServerConn(c1, serverConf)
wg.Done()
}()
clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"}
_, _, _, err = NewClientConn(c2, "", clientConf)
if err == nil {
@ -818,14 +857,22 @@ func TestServerClientAuthCallback(t *testing.T) {
User: someUsername,
}
var wg sync.WaitGroup
t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
defer wg.Done()
_, chans, reqs, err := NewServerConn(c1, serverConf)
if err != nil {
t.Errorf("server handshake: %v", err)
userCh <- "error"
return
}
go DiscardRequests(reqs)
wg.Add(1)
go func() {
DiscardRequests(reqs)
wg.Done()
}()
for ch := range chans {
ch.Reject(Prohibited, "")
}