diff --git a/gopls/integration/replay/main.go b/gopls/integration/replay/main.go index 0a1ff6b8a..49486ff72 100644 --- a/gopls/integration/replay/main.go +++ b/gopls/integration/replay/main.go @@ -8,7 +8,6 @@ package main import ( "bufio" "context" - "encoding/json" "flag" "fmt" "log" @@ -148,7 +147,7 @@ func send(ctx context.Context, l *parse.Logmsg, stream jsonrpc2.Stream, id *json } id = jsonrpc2.NewIntID(int64(n)) } - var msg interface{} + var msg jsonrpc2.Message var err error switch l.Type { case parse.ClRequest: @@ -163,11 +162,7 @@ func send(ctx context.Context, l *parse.Logmsg, stream jsonrpc2.Stream, id *json if err != nil { log.Fatal(err) } - data, err := json.Marshal(msg) - if err != nil { - log.Fatal(err) - } - stream.Write(ctx, data) + stream.Write(ctx, msg) } func respond(ctx context.Context, c *jsonrpc2.Call, stream jsonrpc2.Stream) { @@ -235,15 +230,11 @@ func mimic(ctx context.Context) { rchan := make(chan jsonrpc2.Message, 10) // do we need buffering? rdr := func() { for { - buf, _, err := stream.Read(ctx) + msg, _, err := stream.Read(ctx) if err != nil { rchan <- nil // close it instead? return } - msg, err := jsonrpc2.DecodeMessage(buf) - if err != nil { - log.Fatal(err) - } rchan <- msg } } diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go index b7364fa31..f340dc5c2 100644 --- a/internal/jsonrpc2/jsonrpc2.go +++ b/internal/jsonrpc2/jsonrpc2.go @@ -57,10 +57,6 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (e if err != nil { return fmt.Errorf("marshaling notify parameters: %v", err) } - data, err := json.Marshal(notify) - if err != nil { - return fmt.Errorf("marshaling notify request: %v", err) - } ctx, done := event.StartSpan(ctx, method, tag.Method.Of(method), tag.RPCDirection.Of(tag.Outbound), @@ -71,7 +67,7 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (e }() event.Record(ctx, tag.Started.Of(1)) - n, err := c.stream.Write(ctx, data) + n, err := c.stream.Write(ctx, notify) event.Record(ctx, tag.SentBytes.Of(n)) return err } @@ -86,11 +82,6 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface if err != nil { return id, fmt.Errorf("marshaling call parameters: %v", err) } - // marshal the request now it is complete - data, err := json.Marshal(call) - if err != nil { - return id, fmt.Errorf("marshaling call request: %v", err) - } ctx, done := event.StartSpan(ctx, method, tag.Method.Of(method), tag.RPCDirection.Of(tag.Outbound), @@ -115,7 +106,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface c.pendingMu.Unlock() }() // now we are ready to send - n, err := c.stream.Write(ctx, data) + n, err := c.stream.Write(ctx, call) event.Record(ctx, tag.SentBytes.Of(n)) if err != nil { // sending failed, we will never get a response, so don't leave it pending @@ -155,10 +146,8 @@ func replier(conn *Conn, req Request, spanDone func()) Replier { if err != nil { return err } - data, err := json.Marshal(response) - n, err := conn.stream.Write(ctx, data) + n, err := conn.stream.Write(ctx, response) event.Record(ctx, tag.SentBytes.Of(n)) - if err != nil { // TODO(iancottrell): if a stream write fails, we really need to shut down // the whole stream @@ -174,20 +163,13 @@ func replier(conn *Conn, req Request, spanDone func()) Replier { // It returns only when the reader is closed or there is an error in the stream. func (c *Conn) Run(runCtx context.Context, handler Handler) error { for { - // get the data for a message - data, n, err := c.stream.Read(runCtx) + // get the next message + msg, n, err := c.stream.Read(runCtx) if err != nil { // The stream failed, we cannot continue. If the client disconnected // normally, we should get ErrDisconnected here. return err } - // read a combined message - msg, err := DecodeMessage(data) - if err != nil { - // a badly formed message arrived, log it and continue - // we trust the stream to have isolated the error to just this message - continue - } switch msg := msg.(type) { case Request: tags := []event.Tag{ diff --git a/internal/jsonrpc2/stream.go b/internal/jsonrpc2/stream.go index 3276c1ac4..dc2ebd4a3 100644 --- a/internal/jsonrpc2/stream.go +++ b/internal/jsonrpc2/stream.go @@ -22,10 +22,10 @@ import ( type Stream interface { // Read gets the next message from the stream. // It is never called concurrently. - Read(context.Context) ([]byte, int64, error) + Read(context.Context) (Message, int64, error) // Write sends a message to the stream. // It must be safe for concurrent use. - Write(context.Context, []byte) (int64, error) + Write(context.Context, Message) (int64, error) } // NewStream returns a Stream built on top of an io.Reader and io.Writer @@ -44,7 +44,7 @@ type plainStream struct { out io.Writer } -func (s *plainStream) Read(ctx context.Context) ([]byte, int64, error) { +func (s *plainStream) Read(ctx context.Context) (Message, int64, error) { select { case <-ctx.Done(): return nil, 0, ctx.Err() @@ -57,15 +57,20 @@ func (s *plainStream) Read(ctx context.Context) ([]byte, int64, error) { } return nil, 0, err } - return raw, int64(len(raw)), nil + msg, err := DecodeMessage(raw) + return msg, int64(len(raw)), err } -func (s *plainStream) Write(ctx context.Context, data []byte) (int64, error) { +func (s *plainStream) Write(ctx context.Context, msg Message) (int64, error) { select { case <-ctx.Done(): return 0, ctx.Err() default: } + data, err := json.Marshal(msg) + if err != nil { + return 0, fmt.Errorf("marshaling message: %v", err) + } s.outMu.Lock() n, err := s.out.Write(data) s.outMu.Unlock() @@ -88,7 +93,7 @@ type headerStream struct { out io.Writer } -func (s *headerStream) Read(ctx context.Context) ([]byte, int64, error) { +func (s *headerStream) Read(ctx context.Context) (Message, int64, error) { select { case <-ctx.Done(): return nil, 0, ctx.Err() @@ -136,15 +141,20 @@ func (s *headerStream) Read(ctx context.Context) ([]byte, int64, error) { return nil, total, err } total += length - return data, total, nil + msg, err := DecodeMessage(data) + return msg, total, err } -func (s *headerStream) Write(ctx context.Context, data []byte) (int64, error) { +func (s *headerStream) Write(ctx context.Context, msg Message) (int64, error) { select { case <-ctx.Done(): return 0, ctx.Err() default: } + data, err := json.Marshal(msg) + if err != nil { + return 0, fmt.Errorf("marshaling message: %v", err) + } s.outMu.Lock() defer s.outMu.Unlock() n, err := fmt.Fprintf(s.out, "Content-Length: %v\r\n\r\n", len(data)) diff --git a/internal/lsp/protocol/log.go b/internal/lsp/protocol/log.go index afc037c95..dfa4b6982 100644 --- a/internal/lsp/protocol/log.go +++ b/internal/lsp/protocol/log.go @@ -22,21 +22,21 @@ func LoggingStream(str jsonrpc2.Stream, w io.Writer) jsonrpc2.Stream { return &loggingStream{stream: str, log: w} } -func (s *loggingStream) Read(ctx context.Context) ([]byte, int64, error) { - data, count, err := s.stream.Read(ctx) +func (s *loggingStream) Read(ctx context.Context) (jsonrpc2.Message, int64, error) { + msg, count, err := s.stream.Read(ctx) if err == nil { s.logMu.Lock() defer s.logMu.Unlock() - logIn(s.log, data) + logIn(s.log, msg) } - return data, count, err + return msg, count, err } -func (s *loggingStream) Write(ctx context.Context, data []byte) (int64, error) { +func (s *loggingStream) Write(ctx context.Context, msg jsonrpc2.Message) (int64, error) { s.logMu.Lock() defer s.logMu.Unlock() - logOut(s.log, data) - count, err := s.stream.Write(ctx, data) + logOut(s.log, msg) + count, err := s.stream.Write(ctx, msg) return count, err } @@ -94,26 +94,21 @@ func (m *mapped) setServer(id string, r req) { const eor = "\r\n\r\n\r\n" -func logCommon(outfd io.Writer, data []byte) (jsonrpc2.Message, time.Time, string) { +func logCommon(outfd io.Writer) (time.Time, string) { if outfd == nil { - return nil, time.Time{}, "" - } - v, err := jsonrpc2.DecodeMessage(data) - if err != nil { - fmt.Fprintf(outfd, "Unmarshal %v\n", err) - panic(err) // do better + return time.Time{}, "" } tm := time.Now() tmfmt := tm.Format("15:04:05.000 PM") - return v, tm, tmfmt + return tm, tmfmt } // logOut and logIn could be combined. "received"<->"Sending", serverCalls<->clientCalls // but it wouldn't be a lot shorter or clearer and "shutdown" is a special case // Writing a message to the client, log it -func logOut(outfd io.Writer, data []byte) { - msg, tm, tmfmt := logCommon(outfd, data) +func logOut(outfd io.Writer, msg jsonrpc2.Message) { + tm, tmfmt := logCommon(outfd) if msg == nil { return } @@ -145,8 +140,8 @@ func logOut(outfd io.Writer, data []byte) { } // Got a message from the client, log it -func logIn(outfd io.Writer, data []byte) { - msg, tm, tmfmt := logCommon(outfd, data) +func logIn(outfd io.Writer, msg jsonrpc2.Message) { + tm, tmfmt := logCommon(outfd) if msg == nil { return }