ssh: eliminate some goroutine leaks in tests and examples
This should fix the "Log in goroutine" panic seen in https://build.golang.org/log/e42bf69fc002113dbccfe602a6c67fd52e8f31df, as well as a few other related leaks. It also helps to verify that none of the functions under test deadlock unexpectedly. See https://go.dev/wiki/CodeReviewComments#goroutine-lifetimes. Updates golang/go#58901. Change-Id: Ica943444db381ae1accb80b101ea646e28ebf4f9 Reviewed-on: https://go-review.googlesource.com/c/crypto/+/541095 Auto-Submit: Bryan Mills <bcmills@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Nicola Murino <nicola.murino@gmail.com> Reviewed-by: Heschi Kreinick <heschi@google.com>
This commit is contained in:
Родитель
eb61739cd9
Коммит
ff15cd57d1
|
@ -16,6 +16,7 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
|
@ -98,8 +99,15 @@ func ExampleNewServerConn() {
|
|||
}
|
||||
log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"])
|
||||
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
// The incoming Request channel must be serviced.
|
||||
go ssh.DiscardRequests(reqs)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
ssh.DiscardRequests(reqs)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// Service the incoming Channel channel.
|
||||
for newChannel := range chans {
|
||||
|
@ -119,16 +127,22 @@ func ExampleNewServerConn() {
|
|||
// Sessions have out-of-band requests such as "shell",
|
||||
// "pty-req" and "env". Here we handle only the
|
||||
// "shell" request.
|
||||
wg.Add(1)
|
||||
go func(in <-chan *ssh.Request) {
|
||||
for req := range in {
|
||||
req.Reply(req.Type == "shell", nil)
|
||||
}
|
||||
wg.Done()
|
||||
}(requests)
|
||||
|
||||
term := terminal.NewTerminal(channel, "> ")
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer channel.Close()
|
||||
defer func() {
|
||||
channel.Close()
|
||||
wg.Done()
|
||||
}()
|
||||
for {
|
||||
line, err := term.ReadLine()
|
||||
if err != nil {
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func muxPair() (*mux, *mux) {
|
||||
|
@ -112,7 +111,11 @@ func TestMuxReadWrite(t *testing.T) {
|
|||
|
||||
magic := "hello world"
|
||||
magicExt := "hello stderr"
|
||||
var wg sync.WaitGroup
|
||||
t.Cleanup(wg.Wait)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := s.Write([]byte(magic))
|
||||
if err != nil {
|
||||
t.Errorf("Write: %v", err)
|
||||
|
@ -152,13 +155,15 @@ func TestMuxChannelOverflow(t *testing.T) {
|
|||
defer writer.Close()
|
||||
defer mux.Close()
|
||||
|
||||
wDone := make(chan int, 1)
|
||||
var wg sync.WaitGroup
|
||||
t.Cleanup(wg.Wait)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
|
||||
t.Errorf("could not fill window: %v", err)
|
||||
}
|
||||
writer.Write(make([]byte, 1))
|
||||
wDone <- 1
|
||||
}()
|
||||
writer.remoteWin.waitWriterBlocked()
|
||||
|
||||
|
@ -175,7 +180,6 @@ func TestMuxChannelOverflow(t *testing.T) {
|
|||
if _, err := reader.SendRequest("hello", true, nil); err == nil {
|
||||
t.Errorf("SendRequest succeeded.")
|
||||
}
|
||||
<-wDone
|
||||
}
|
||||
|
||||
func TestMuxChannelCloseWriteUnblock(t *testing.T) {
|
||||
|
@ -184,20 +188,21 @@ func TestMuxChannelCloseWriteUnblock(t *testing.T) {
|
|||
defer writer.Close()
|
||||
defer mux.Close()
|
||||
|
||||
wDone := make(chan int, 1)
|
||||
var wg sync.WaitGroup
|
||||
t.Cleanup(wg.Wait)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
|
||||
t.Errorf("could not fill window: %v", err)
|
||||
}
|
||||
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
|
||||
t.Errorf("got %v, want EOF for unblock write", err)
|
||||
}
|
||||
wDone <- 1
|
||||
}()
|
||||
|
||||
writer.remoteWin.waitWriterBlocked()
|
||||
reader.Close()
|
||||
<-wDone
|
||||
}
|
||||
|
||||
func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
|
||||
|
@ -206,20 +211,21 @@ func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
|
|||
defer writer.Close()
|
||||
defer mux.Close()
|
||||
|
||||
wDone := make(chan int, 1)
|
||||
var wg sync.WaitGroup
|
||||
t.Cleanup(wg.Wait)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
|
||||
t.Errorf("could not fill window: %v", err)
|
||||
}
|
||||
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
|
||||
t.Errorf("got %v, want EOF for unblock write", err)
|
||||
}
|
||||
wDone <- 1
|
||||
}()
|
||||
|
||||
writer.remoteWin.waitWriterBlocked()
|
||||
mux.Close()
|
||||
<-wDone
|
||||
}
|
||||
|
||||
func TestMuxReject(t *testing.T) {
|
||||
|
@ -227,7 +233,12 @@ func TestMuxReject(t *testing.T) {
|
|||
defer server.Close()
|
||||
defer client.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
t.Cleanup(wg.Wait)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
ch, ok := <-server.incomingChannels
|
||||
if !ok {
|
||||
t.Error("cannot accept channel")
|
||||
|
@ -267,6 +278,7 @@ func TestMuxChannelRequest(t *testing.T) {
|
|||
|
||||
var received int
|
||||
var wg sync.WaitGroup
|
||||
t.Cleanup(wg.Wait)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
for r := range server.incomingRequests {
|
||||
|
@ -295,7 +307,6 @@ func TestMuxChannelRequest(t *testing.T) {
|
|||
}
|
||||
if ok {
|
||||
t.Errorf("SendRequest(no): %v", ok)
|
||||
|
||||
}
|
||||
|
||||
client.Close()
|
||||
|
@ -389,13 +400,8 @@ func TestMuxUnknownChannelRequests(t *testing.T) {
|
|||
|
||||
// Wait for the server to send the keepalive message and receive back a
|
||||
// response.
|
||||
select {
|
||||
case err := <-kDone:
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatalf("server never received ack")
|
||||
if err := <-kDone; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Confirm client hasn't closed.
|
||||
|
@ -403,13 +409,9 @@ func TestMuxUnknownChannelRequests(t *testing.T) {
|
|||
t.Fatalf("failed to send keepalive: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-kDone:
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatalf("server never shut down")
|
||||
// Wait for the server to shut down.
|
||||
if err := <-kDone; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -525,11 +527,7 @@ func TestMuxClosedChannel(t *testing.T) {
|
|||
defer ch.Close()
|
||||
|
||||
// Wait for the server to close the channel and send the keepalive.
|
||||
select {
|
||||
case <-kDone:
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatalf("server never received ack")
|
||||
}
|
||||
<-kDone
|
||||
|
||||
// Make sure the channel closed.
|
||||
if _, ok := <-ch.incomingRequests; ok {
|
||||
|
@ -541,22 +539,29 @@ func TestMuxClosedChannel(t *testing.T) {
|
|||
t.Fatalf("failed to send keepalive: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-kDone:
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatalf("server never shut down")
|
||||
}
|
||||
// Wait for the server to shut down.
|
||||
<-kDone
|
||||
}
|
||||
|
||||
func TestMuxGlobalRequest(t *testing.T) {
|
||||
var sawPeek bool
|
||||
var wg sync.WaitGroup
|
||||
defer func() {
|
||||
wg.Wait()
|
||||
if !sawPeek {
|
||||
t.Errorf("never saw 'peek' request")
|
||||
}
|
||||
}()
|
||||
|
||||
clientMux, serverMux := muxPair()
|
||||
defer serverMux.Close()
|
||||
defer clientMux.Close()
|
||||
|
||||
var seen bool
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for r := range serverMux.incomingRequests {
|
||||
seen = seen || r.Type == "peek"
|
||||
sawPeek = sawPeek || r.Type == "peek"
|
||||
if r.WantReply {
|
||||
err := r.Reply(r.Type == "yes",
|
||||
append([]byte(r.Type), r.Payload...))
|
||||
|
@ -586,10 +591,6 @@ func TestMuxGlobalRequest(t *testing.T) {
|
|||
t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
|
||||
ok, data, err)
|
||||
}
|
||||
|
||||
if !seen {
|
||||
t.Errorf("never saw 'peek' request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMuxGlobalRequestUnblock(t *testing.T) {
|
||||
|
@ -739,7 +740,13 @@ func TestMuxMaxPacketSize(t *testing.T) {
|
|||
t.Errorf("could not send packet")
|
||||
}
|
||||
|
||||
go a.SendRequest("hello", false, nil)
|
||||
var wg sync.WaitGroup
|
||||
t.Cleanup(wg.Wait)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
a.SendRequest("hello", false, nil)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
_, ok := <-b.incomingRequests
|
||||
if ok {
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
|
@ -27,8 +28,14 @@ func dial(handler serverType, t *testing.T) *Client {
|
|||
t.Fatalf("netPipe: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
t.Cleanup(wg.Wait)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer c1.Close()
|
||||
defer func() {
|
||||
c1.Close()
|
||||
wg.Done()
|
||||
}()
|
||||
conf := ServerConfig{
|
||||
NoClientAuth: true,
|
||||
}
|
||||
|
@ -39,7 +46,11 @@ func dial(handler serverType, t *testing.T) *Client {
|
|||
t.Errorf("Unable to handshake: %v", err)
|
||||
return
|
||||
}
|
||||
go DiscardRequests(reqs)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
DiscardRequests(reqs)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
for newCh := range chans {
|
||||
if newCh.ChannelType() != "session" {
|
||||
|
@ -52,8 +63,10 @@ func dial(handler serverType, t *testing.T) *Client {
|
|||
t.Errorf("Accept: %v", err)
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
handler(ch, inReqs, t)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
if err := conn.Wait(); err != io.EOF {
|
||||
|
@ -338,8 +351,13 @@ func TestServerWindow(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
defer session.Close()
|
||||
result := make(chan []byte)
|
||||
|
||||
serverStdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
t.Fatalf("StdinPipe failed: %v", err)
|
||||
}
|
||||
|
||||
result := make(chan []byte)
|
||||
go func() {
|
||||
defer close(result)
|
||||
echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
|
||||
|
@ -355,10 +373,6 @@ func TestServerWindow(t *testing.T) {
|
|||
result <- echoedBuf.Bytes()
|
||||
}()
|
||||
|
||||
serverStdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
t.Fatalf("StdinPipe failed: %v", err)
|
||||
}
|
||||
written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
|
||||
if err != nil {
|
||||
t.Errorf("failed to copy origBuf to serverStdin: %v", err)
|
||||
|
@ -648,29 +662,44 @@ func TestSessionID(t *testing.T) {
|
|||
User: "user",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
t.Cleanup(wg.Wait)
|
||||
|
||||
srvErrCh := make(chan error, 1)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn, chans, reqs, err := NewServerConn(c1, serverConf)
|
||||
srvErrCh <- err
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
serverID <- conn.SessionID()
|
||||
go DiscardRequests(reqs)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
DiscardRequests(reqs)
|
||||
wg.Done()
|
||||
}()
|
||||
for ch := range chans {
|
||||
ch.Reject(Prohibited, "")
|
||||
}
|
||||
}()
|
||||
|
||||
cliErrCh := make(chan error, 1)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
|
||||
cliErrCh <- err
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
clientID <- conn.SessionID()
|
||||
go DiscardRequests(reqs)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
DiscardRequests(reqs)
|
||||
wg.Done()
|
||||
}()
|
||||
for ch := range chans {
|
||||
ch.Reject(Prohibited, "")
|
||||
}
|
||||
|
@ -738,6 +767,8 @@ func TestHostKeyAlgorithms(t *testing.T) {
|
|||
serverConf.AddHostKey(testSigners["rsa"])
|
||||
serverConf.AddHostKey(testSigners["ecdsa"])
|
||||
|
||||
var wg sync.WaitGroup
|
||||
t.Cleanup(wg.Wait)
|
||||
connect := func(clientConf *ClientConfig, want string) {
|
||||
var alg string
|
||||
clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error {
|
||||
|
@ -751,7 +782,11 @@ func TestHostKeyAlgorithms(t *testing.T) {
|
|||
defer c1.Close()
|
||||
defer c2.Close()
|
||||
|
||||
go NewServerConn(c1, serverConf)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
NewServerConn(c1, serverConf)
|
||||
wg.Done()
|
||||
}()
|
||||
_, _, _, err = NewClientConn(c2, "", clientConf)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientConn: %v", err)
|
||||
|
@ -785,7 +820,11 @@ func TestHostKeyAlgorithms(t *testing.T) {
|
|||
defer c1.Close()
|
||||
defer c2.Close()
|
||||
|
||||
go NewServerConn(c1, serverConf)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
NewServerConn(c1, serverConf)
|
||||
wg.Done()
|
||||
}()
|
||||
clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"}
|
||||
_, _, _, err = NewClientConn(c2, "", clientConf)
|
||||
if err == nil {
|
||||
|
@ -818,14 +857,22 @@ func TestServerClientAuthCallback(t *testing.T) {
|
|||
User: someUsername,
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
t.Cleanup(wg.Wait)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, chans, reqs, err := NewServerConn(c1, serverConf)
|
||||
if err != nil {
|
||||
t.Errorf("server handshake: %v", err)
|
||||
userCh <- "error"
|
||||
return
|
||||
}
|
||||
go DiscardRequests(reqs)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
DiscardRequests(reqs)
|
||||
wg.Done()
|
||||
}()
|
||||
for ch := range chans {
|
||||
ch.Reject(Prohibited, "")
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче