зеркало из https://github.com/golang/tools.git
internal/jsonrpc2: add the ability to close connections
Also the ability to wait for them to correctly close. Change-Id: I198c9e24a21c04d5c05bae7a4a0f503429ab0346 Reviewed-on: https://go-review.googlesource.com/c/tools/+/231699 Run-TryBot: Ian Cottrell <iancottrell@google.com> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Jonathan Amsterdam <jba@google.com> Reviewed-by: Robert Findley <rfindley@google.com>
This commit is contained in:
Родитель
2caf76543d
Коммит
e84ca95fee
|
@ -32,6 +32,9 @@ type Conn struct {
|
|||
stream Stream
|
||||
pendingMu sync.Mutex // protects the pending map
|
||||
pending map[ID]chan *Response
|
||||
|
||||
done chan struct{}
|
||||
err atomic.Value
|
||||
}
|
||||
|
||||
type constError string
|
||||
|
@ -44,6 +47,7 @@ func NewConn(s Stream) *Conn {
|
|||
conn := &Conn{
|
||||
stream: s,
|
||||
pending: make(map[ID]chan *Response),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
@ -162,18 +166,25 @@ func (c *Conn) write(ctx context.Context, msg Message) (int64, error) {
|
|||
return c.stream.Write(ctx, msg)
|
||||
}
|
||||
|
||||
// Run blocks until the connection is terminated, and returns any error that
|
||||
// caused the termination.
|
||||
// Go starts a goroutine to handle the connection.
|
||||
// It must be called exactly once for each Conn.
|
||||
// It returns only when the reader is closed or there is an error in the stream.
|
||||
func (c *Conn) Run(runCtx context.Context, handler Handler) error {
|
||||
// It returns immediately.
|
||||
// You must block on Done() to wait for the connection to shut down.
|
||||
// This is a temporary measure, this should be started automatically in the
|
||||
// future.
|
||||
func (c *Conn) Go(ctx context.Context, handler Handler) {
|
||||
go c.run(ctx, handler)
|
||||
}
|
||||
|
||||
func (c *Conn) run(ctx context.Context, handler Handler) {
|
||||
defer close(c.done)
|
||||
for {
|
||||
// get the next message
|
||||
msg, n, err := c.stream.Read(runCtx)
|
||||
msg, n, err := c.stream.Read(ctx)
|
||||
if err != nil {
|
||||
// The stream failed, we cannot continue. If the client disconnected
|
||||
// normally, we should get ErrDisconnected here.
|
||||
return err
|
||||
// The stream failed, we cannot continue.
|
||||
c.fail(err)
|
||||
return
|
||||
}
|
||||
switch msg := msg.(type) {
|
||||
case Request:
|
||||
|
@ -187,7 +198,7 @@ func (c *Conn) Run(runCtx context.Context, handler Handler) error {
|
|||
} else {
|
||||
labels = labels[:len(labels)-1]
|
||||
}
|
||||
reqCtx, spanDone := event.Start(runCtx, msg.Method(), labels...)
|
||||
reqCtx, spanDone := event.Start(ctx, msg.Method(), labels...)
|
||||
event.Metric(reqCtx,
|
||||
tag.Started.Of(1),
|
||||
tag.ReceivedBytes.Of(n))
|
||||
|
@ -208,6 +219,32 @@ func (c *Conn) Run(runCtx context.Context, handler Handler) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Close closes the underlying stream.
|
||||
// This does not wait for the underlying handler to finish, block on the done
|
||||
// channel with <-Done() for that purpose.
|
||||
func (c *Conn) Close() error {
|
||||
return c.stream.Close()
|
||||
}
|
||||
|
||||
// Done returns a channel that will be closed when the processing goroutine has
|
||||
// terminated, which will happen if Close() is called or the underlying
|
||||
// stream is closed.
|
||||
func (c *Conn) Done() <-chan struct{} {
|
||||
return c.done
|
||||
}
|
||||
|
||||
// Err returns an error if there was one from within the processing goroutine.
|
||||
// If err returns non nil, the connection will be already closed or closing.
|
||||
func (c *Conn) Err() error {
|
||||
return c.err.Load().(error)
|
||||
}
|
||||
|
||||
// fail sets a failure condition on the stream and closes it.
|
||||
func (c *Conn) fail(err error) {
|
||||
c.err.Store(err)
|
||||
c.stream.Close()
|
||||
}
|
||||
|
||||
func marshalToRaw(obj interface{}) (json.RawMessage, error) {
|
||||
data, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
|
|
|
@ -7,13 +7,11 @@ package jsonrpc2_test
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/tools/internal/event/export/eventtest"
|
||||
|
@ -94,21 +92,19 @@ func TestCall(t *testing.T) {
|
|||
|
||||
func prepare(ctx context.Context, t *testing.T, withHeaders bool) (*jsonrpc2.Conn, *jsonrpc2.Conn, func()) {
|
||||
// make a wait group that can be used to wait for the system to shut down
|
||||
wg := &sync.WaitGroup{}
|
||||
aR, bW := io.Pipe()
|
||||
bR, aW := io.Pipe()
|
||||
a := run(ctx, t, withHeaders, aR, aW, wg)
|
||||
b := run(ctx, t, withHeaders, bR, bW, wg)
|
||||
a := run(ctx, withHeaders, aR, aW)
|
||||
b := run(ctx, withHeaders, bR, bW)
|
||||
return a, b, func() {
|
||||
// we close the main writer, this should cascade through the server and
|
||||
// cause normal shutdown of the entire chain
|
||||
aW.Close()
|
||||
// this should then wait for that entire cascade,
|
||||
wg.Wait()
|
||||
a.Close()
|
||||
b.Close()
|
||||
<-a.Done()
|
||||
<-b.Done()
|
||||
}
|
||||
}
|
||||
|
||||
func run(ctx context.Context, t *testing.T, withHeaders bool, r io.ReadCloser, w io.WriteCloser, wg *sync.WaitGroup) *jsonrpc2.Conn {
|
||||
func run(ctx context.Context, withHeaders bool, r io.ReadCloser, w io.WriteCloser) *jsonrpc2.Conn {
|
||||
var stream jsonrpc2.Stream
|
||||
if withHeaders {
|
||||
stream = jsonrpc2.NewHeaderStream(r, w)
|
||||
|
@ -116,18 +112,7 @@ func run(ctx context.Context, t *testing.T, withHeaders bool, r io.ReadCloser, w
|
|||
stream = jsonrpc2.NewRawStream(r, w)
|
||||
}
|
||||
conn := jsonrpc2.NewConn(stream)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
stream.Close()
|
||||
// and then signal that this connection is done
|
||||
wg.Done()
|
||||
}()
|
||||
err := conn.Run(ctx, testHandler(*logRPC))
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrClosedPipe) {
|
||||
t.Errorf("Stream failed: %v", err)
|
||||
}
|
||||
}()
|
||||
conn.Go(ctx, testHandler(*logRPC))
|
||||
return conn
|
||||
}
|
||||
|
||||
|
|
|
@ -39,7 +39,9 @@ func (f ServerFunc) ServeStream(ctx context.Context, s Stream) error {
|
|||
func HandlerServer(h Handler) StreamServer {
|
||||
return ServerFunc(func(ctx context.Context, s Stream) error {
|
||||
conn := NewConn(s)
|
||||
return conn.Run(ctx, h)
|
||||
conn.Go(ctx, h)
|
||||
<-conn.Done()
|
||||
return conn.Err()
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ func TestTestServer(t *testing.T) {
|
|||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
conn := test.connector.Connect(ctx)
|
||||
go conn.Run(ctx, jsonrpc2.MethodNotFound)
|
||||
conn.Go(ctx, jsonrpc2.MethodNotFound)
|
||||
var got msg
|
||||
if _, err := conn.Call(ctx, "ping", &msg{"ping"}, &got); err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
|
@ -234,7 +234,7 @@ func (app *Application) connectRemote(ctx context.Context, remote string) (*conn
|
|||
cc := jsonrpc2.NewConn(stream)
|
||||
connection.Server = protocol.ServerDispatcher(cc)
|
||||
ctx = protocol.WithClient(ctx, connection.Client)
|
||||
go cc.Run(ctx,
|
||||
cc.Go(ctx,
|
||||
protocol.Handlers(
|
||||
protocol.ClientHandler(connection.Client,
|
||||
jsonrpc2.MethodNotFound)))
|
||||
|
|
|
@ -83,7 +83,7 @@ func NewEditor(ws *Sandbox, config EditorConfig) *Editor {
|
|||
func (e *Editor) Connect(ctx context.Context, conn *jsonrpc2.Conn, hooks ClientHooks) (*Editor, error) {
|
||||
e.Server = protocol.ServerDispatcher(conn)
|
||||
e.client = &Client{editor: e, hooks: hooks}
|
||||
go conn.Run(ctx,
|
||||
conn.Go(ctx,
|
||||
protocol.Handlers(
|
||||
protocol.ClientHandler(e.client,
|
||||
jsonrpc2.MethodNotFound)))
|
||||
|
|
|
@ -17,7 +17,6 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/tools/internal/event"
|
||||
"golang.org/x/tools/internal/jsonrpc2"
|
||||
"golang.org/x/tools/internal/lsp"
|
||||
|
@ -140,11 +139,13 @@ func (s *StreamServer) ServeStream(ctx context.Context, stream jsonrpc2.Stream)
|
|||
executable = ""
|
||||
}
|
||||
ctx = protocol.WithClient(ctx, client)
|
||||
return conn.Run(ctx,
|
||||
conn.Go(ctx,
|
||||
protocol.Handlers(
|
||||
handshaker(dc, executable,
|
||||
protocol.ServerHandler(server,
|
||||
jsonrpc2.MethodNotFound))))
|
||||
<-conn.Done()
|
||||
return conn.Err()
|
||||
}
|
||||
|
||||
// A Forwarder is a jsonrpc2.StreamServer that handles an LSP stream by
|
||||
|
@ -234,7 +235,7 @@ func QueryServerState(ctx context.Context, network, address string) (*ServerStat
|
|||
return nil, fmt.Errorf("dialing remote: %w", err)
|
||||
}
|
||||
serverConn := jsonrpc2.NewConn(jsonrpc2.NewHeaderStream(netConn, netConn))
|
||||
go serverConn.Run(ctx, jsonrpc2.MethodNotFound)
|
||||
serverConn.Go(ctx, jsonrpc2.MethodNotFound)
|
||||
var state ServerState
|
||||
if err := protocol.Call(ctx, serverConn, sessionsMethod, nil, &state); err != nil {
|
||||
return nil, fmt.Errorf("querying server state: %w", err)
|
||||
|
@ -256,13 +257,10 @@ func (f *Forwarder) ServeStream(ctx context.Context, stream jsonrpc2.Stream) err
|
|||
server := protocol.ServerDispatcher(serverConn)
|
||||
|
||||
// Forward between connections.
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.Go(func() error {
|
||||
return serverConn.Run(ctx,
|
||||
serverConn.Go(ctx,
|
||||
protocol.Handlers(
|
||||
protocol.ClientHandler(client,
|
||||
jsonrpc2.MethodNotFound)))
|
||||
})
|
||||
// Don't run the clientConn yet, so that we can complete the handshake before
|
||||
// processing any client messages.
|
||||
|
||||
|
@ -298,15 +296,19 @@ func (f *Forwarder) ServeStream(ctx context.Context, stream jsonrpc2.Stream) err
|
|||
clientID: hresp.ClientID,
|
||||
})
|
||||
}
|
||||
g.Go(func() error {
|
||||
return clientConn.Run(ctx,
|
||||
clientConn.Go(ctx,
|
||||
protocol.Handlers(
|
||||
forwarderHandler(
|
||||
protocol.ServerHandler(server,
|
||||
jsonrpc2.MethodNotFound))))
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
<-serverConn.Done()
|
||||
<-clientConn.Done()
|
||||
err = serverConn.Err()
|
||||
if err == nil {
|
||||
err = clientConn.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (f *Forwarder) connectToRemote(ctx context.Context) (net.Conn, error) {
|
||||
|
|
|
@ -55,7 +55,7 @@ func TestClientLogging(t *testing.T) {
|
|||
ts := servertest.NewPipeServer(ctx, ss)
|
||||
defer checkClose(t, ts.Close)
|
||||
cc := ts.Connect(ctx)
|
||||
go cc.Run(ctx, protocol.ClientHandler(client, jsonrpc2.MethodNotFound))
|
||||
cc.Go(ctx, protocol.ClientHandler(client, jsonrpc2.MethodNotFound))
|
||||
|
||||
protocol.ServerDispatcher(cc).DidOpen(ctx, &protocol.DidOpenTextDocumentParams{})
|
||||
|
||||
|
@ -136,7 +136,7 @@ func TestRequestCancellation(t *testing.T) {
|
|||
t.Run(test.serverType, func(t *testing.T) {
|
||||
cc := test.ts.Connect(baseCtx)
|
||||
sd := protocol.ServerDispatcher(cc)
|
||||
go cc.Run(baseCtx,
|
||||
cc.Go(baseCtx,
|
||||
protocol.Handlers(
|
||||
jsonrpc2.MethodNotFound))
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче