internal/jsonrpc2: rewrite streams in terms of messages

messages are the atomic unit of communication, changing streams
to read and write whole messages makes the code clearer.
It also avoids the confusion about what should be an atomic
operation or when a stream should flush.

Change-Id: I4b731c9518ad7c2be92fc92211c33f32d809f38b
Reviewed-on: https://go-review.googlesource.com/c/tools/+/228722
Run-TryBot: Ian Cottrell <iancottrell@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Robert Findley <rfindley@google.com>
This commit is contained in:
Ian Cottrell 2020-04-16 21:49:42 -04:00
Родитель cfa8b22178
Коммит 2dc4334630
4 изменённых файлов: 40 добавлений и 62 удалений

Просмотреть файл

@ -8,7 +8,6 @@ package main
import ( import (
"bufio" "bufio"
"context" "context"
"encoding/json"
"flag" "flag"
"fmt" "fmt"
"log" "log"
@ -148,7 +147,7 @@ func send(ctx context.Context, l *parse.Logmsg, stream jsonrpc2.Stream, id *json
} }
id = jsonrpc2.NewIntID(int64(n)) id = jsonrpc2.NewIntID(int64(n))
} }
var msg interface{} var msg jsonrpc2.Message
var err error var err error
switch l.Type { switch l.Type {
case parse.ClRequest: case parse.ClRequest:
@ -163,11 +162,7 @@ func send(ctx context.Context, l *parse.Logmsg, stream jsonrpc2.Stream, id *json
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
data, err := json.Marshal(msg) stream.Write(ctx, msg)
if err != nil {
log.Fatal(err)
}
stream.Write(ctx, data)
} }
func respond(ctx context.Context, c *jsonrpc2.Call, stream jsonrpc2.Stream) { func respond(ctx context.Context, c *jsonrpc2.Call, stream jsonrpc2.Stream) {
@ -235,15 +230,11 @@ func mimic(ctx context.Context) {
rchan := make(chan jsonrpc2.Message, 10) // do we need buffering? rchan := make(chan jsonrpc2.Message, 10) // do we need buffering?
rdr := func() { rdr := func() {
for { for {
buf, _, err := stream.Read(ctx) msg, _, err := stream.Read(ctx)
if err != nil { if err != nil {
rchan <- nil // close it instead? rchan <- nil // close it instead?
return return
} }
msg, err := jsonrpc2.DecodeMessage(buf)
if err != nil {
log.Fatal(err)
}
rchan <- msg rchan <- msg
} }
} }

Просмотреть файл

@ -57,10 +57,6 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (e
if err != nil { if err != nil {
return fmt.Errorf("marshaling notify parameters: %v", err) return fmt.Errorf("marshaling notify parameters: %v", err)
} }
data, err := json.Marshal(notify)
if err != nil {
return fmt.Errorf("marshaling notify request: %v", err)
}
ctx, done := event.StartSpan(ctx, method, ctx, done := event.StartSpan(ctx, method,
tag.Method.Of(method), tag.Method.Of(method),
tag.RPCDirection.Of(tag.Outbound), tag.RPCDirection.Of(tag.Outbound),
@ -71,7 +67,7 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (e
}() }()
event.Record(ctx, tag.Started.Of(1)) event.Record(ctx, tag.Started.Of(1))
n, err := c.stream.Write(ctx, data) n, err := c.stream.Write(ctx, notify)
event.Record(ctx, tag.SentBytes.Of(n)) event.Record(ctx, tag.SentBytes.Of(n))
return err return err
} }
@ -86,11 +82,6 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
if err != nil { if err != nil {
return id, fmt.Errorf("marshaling call parameters: %v", err) return id, fmt.Errorf("marshaling call parameters: %v", err)
} }
// marshal the request now it is complete
data, err := json.Marshal(call)
if err != nil {
return id, fmt.Errorf("marshaling call request: %v", err)
}
ctx, done := event.StartSpan(ctx, method, ctx, done := event.StartSpan(ctx, method,
tag.Method.Of(method), tag.Method.Of(method),
tag.RPCDirection.Of(tag.Outbound), tag.RPCDirection.Of(tag.Outbound),
@ -115,7 +106,7 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface
c.pendingMu.Unlock() c.pendingMu.Unlock()
}() }()
// now we are ready to send // now we are ready to send
n, err := c.stream.Write(ctx, data) n, err := c.stream.Write(ctx, call)
event.Record(ctx, tag.SentBytes.Of(n)) event.Record(ctx, tag.SentBytes.Of(n))
if err != nil { if err != nil {
// sending failed, we will never get a response, so don't leave it pending // sending failed, we will never get a response, so don't leave it pending
@ -155,10 +146,8 @@ func replier(conn *Conn, req Request, spanDone func()) Replier {
if err != nil { if err != nil {
return err return err
} }
data, err := json.Marshal(response) n, err := conn.stream.Write(ctx, response)
n, err := conn.stream.Write(ctx, data)
event.Record(ctx, tag.SentBytes.Of(n)) event.Record(ctx, tag.SentBytes.Of(n))
if err != nil { if err != nil {
// TODO(iancottrell): if a stream write fails, we really need to shut down // TODO(iancottrell): if a stream write fails, we really need to shut down
// the whole stream // the whole stream
@ -174,20 +163,13 @@ func replier(conn *Conn, req Request, spanDone func()) Replier {
// It returns only when the reader is closed or there is an error in the stream. // It returns only when the reader is closed or there is an error in the stream.
func (c *Conn) Run(runCtx context.Context, handler Handler) error { func (c *Conn) Run(runCtx context.Context, handler Handler) error {
for { for {
// get the data for a message // get the next message
data, n, err := c.stream.Read(runCtx) msg, n, err := c.stream.Read(runCtx)
if err != nil { if err != nil {
// The stream failed, we cannot continue. If the client disconnected // The stream failed, we cannot continue. If the client disconnected
// normally, we should get ErrDisconnected here. // normally, we should get ErrDisconnected here.
return err return err
} }
// read a combined message
msg, err := DecodeMessage(data)
if err != nil {
// a badly formed message arrived, log it and continue
// we trust the stream to have isolated the error to just this message
continue
}
switch msg := msg.(type) { switch msg := msg.(type) {
case Request: case Request:
tags := []event.Tag{ tags := []event.Tag{

Просмотреть файл

@ -22,10 +22,10 @@ import (
type Stream interface { type Stream interface {
// Read gets the next message from the stream. // Read gets the next message from the stream.
// It is never called concurrently. // It is never called concurrently.
Read(context.Context) ([]byte, int64, error) Read(context.Context) (Message, int64, error)
// Write sends a message to the stream. // Write sends a message to the stream.
// It must be safe for concurrent use. // It must be safe for concurrent use.
Write(context.Context, []byte) (int64, error) Write(context.Context, Message) (int64, error)
} }
// NewStream returns a Stream built on top of an io.Reader and io.Writer // NewStream returns a Stream built on top of an io.Reader and io.Writer
@ -44,7 +44,7 @@ type plainStream struct {
out io.Writer out io.Writer
} }
func (s *plainStream) Read(ctx context.Context) ([]byte, int64, error) { func (s *plainStream) Read(ctx context.Context) (Message, int64, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, 0, ctx.Err() return nil, 0, ctx.Err()
@ -57,15 +57,20 @@ func (s *plainStream) Read(ctx context.Context) ([]byte, int64, error) {
} }
return nil, 0, err return nil, 0, err
} }
return raw, int64(len(raw)), nil msg, err := DecodeMessage(raw)
return msg, int64(len(raw)), err
} }
func (s *plainStream) Write(ctx context.Context, data []byte) (int64, error) { func (s *plainStream) Write(ctx context.Context, msg Message) (int64, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return 0, ctx.Err() return 0, ctx.Err()
default: default:
} }
data, err := json.Marshal(msg)
if err != nil {
return 0, fmt.Errorf("marshaling message: %v", err)
}
s.outMu.Lock() s.outMu.Lock()
n, err := s.out.Write(data) n, err := s.out.Write(data)
s.outMu.Unlock() s.outMu.Unlock()
@ -88,7 +93,7 @@ type headerStream struct {
out io.Writer out io.Writer
} }
func (s *headerStream) Read(ctx context.Context) ([]byte, int64, error) { func (s *headerStream) Read(ctx context.Context) (Message, int64, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, 0, ctx.Err() return nil, 0, ctx.Err()
@ -136,15 +141,20 @@ func (s *headerStream) Read(ctx context.Context) ([]byte, int64, error) {
return nil, total, err return nil, total, err
} }
total += length total += length
return data, total, nil msg, err := DecodeMessage(data)
return msg, total, err
} }
func (s *headerStream) Write(ctx context.Context, data []byte) (int64, error) { func (s *headerStream) Write(ctx context.Context, msg Message) (int64, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return 0, ctx.Err() return 0, ctx.Err()
default: default:
} }
data, err := json.Marshal(msg)
if err != nil {
return 0, fmt.Errorf("marshaling message: %v", err)
}
s.outMu.Lock() s.outMu.Lock()
defer s.outMu.Unlock() defer s.outMu.Unlock()
n, err := fmt.Fprintf(s.out, "Content-Length: %v\r\n\r\n", len(data)) n, err := fmt.Fprintf(s.out, "Content-Length: %v\r\n\r\n", len(data))

Просмотреть файл

@ -22,21 +22,21 @@ func LoggingStream(str jsonrpc2.Stream, w io.Writer) jsonrpc2.Stream {
return &loggingStream{stream: str, log: w} return &loggingStream{stream: str, log: w}
} }
func (s *loggingStream) Read(ctx context.Context) ([]byte, int64, error) { func (s *loggingStream) Read(ctx context.Context) (jsonrpc2.Message, int64, error) {
data, count, err := s.stream.Read(ctx) msg, count, err := s.stream.Read(ctx)
if err == nil { if err == nil {
s.logMu.Lock() s.logMu.Lock()
defer s.logMu.Unlock() defer s.logMu.Unlock()
logIn(s.log, data) logIn(s.log, msg)
} }
return data, count, err return msg, count, err
} }
func (s *loggingStream) Write(ctx context.Context, data []byte) (int64, error) { func (s *loggingStream) Write(ctx context.Context, msg jsonrpc2.Message) (int64, error) {
s.logMu.Lock() s.logMu.Lock()
defer s.logMu.Unlock() defer s.logMu.Unlock()
logOut(s.log, data) logOut(s.log, msg)
count, err := s.stream.Write(ctx, data) count, err := s.stream.Write(ctx, msg)
return count, err return count, err
} }
@ -94,26 +94,21 @@ func (m *mapped) setServer(id string, r req) {
const eor = "\r\n\r\n\r\n" const eor = "\r\n\r\n\r\n"
func logCommon(outfd io.Writer, data []byte) (jsonrpc2.Message, time.Time, string) { func logCommon(outfd io.Writer) (time.Time, string) {
if outfd == nil { if outfd == nil {
return nil, time.Time{}, "" return time.Time{}, ""
}
v, err := jsonrpc2.DecodeMessage(data)
if err != nil {
fmt.Fprintf(outfd, "Unmarshal %v\n", err)
panic(err) // do better
} }
tm := time.Now() tm := time.Now()
tmfmt := tm.Format("15:04:05.000 PM") tmfmt := tm.Format("15:04:05.000 PM")
return v, tm, tmfmt return tm, tmfmt
} }
// logOut and logIn could be combined. "received"<->"Sending", serverCalls<->clientCalls // logOut and logIn could be combined. "received"<->"Sending", serverCalls<->clientCalls
// but it wouldn't be a lot shorter or clearer and "shutdown" is a special case // but it wouldn't be a lot shorter or clearer and "shutdown" is a special case
// Writing a message to the client, log it // Writing a message to the client, log it
func logOut(outfd io.Writer, data []byte) { func logOut(outfd io.Writer, msg jsonrpc2.Message) {
msg, tm, tmfmt := logCommon(outfd, data) tm, tmfmt := logCommon(outfd)
if msg == nil { if msg == nil {
return return
} }
@ -145,8 +140,8 @@ func logOut(outfd io.Writer, data []byte) {
} }
// Got a message from the client, log it // Got a message from the client, log it
func logIn(outfd io.Writer, data []byte) { func logIn(outfd io.Writer, msg jsonrpc2.Message) {
msg, tm, tmfmt := logCommon(outfd, data) tm, tmfmt := logCommon(outfd)
if msg == nil { if msg == nil {
return return
} }