diff --git a/gopls/internal/cmd/cmd.go b/gopls/internal/cmd/cmd.go index 4afac6a7a..91aca4683 100644 --- a/gopls/internal/cmd/cmd.go +++ b/gopls/internal/cmd/cmd.go @@ -12,7 +12,6 @@ import ( "flag" "fmt" "log" - "math/rand" "os" "path/filepath" "reflect" @@ -391,35 +390,13 @@ type connection struct { client *cmdClient } -// registerProgressHandler registers a handler for progress notifications. -// The caller must call unregister when the handler is no longer needed. -func (cli *cmdClient) registerProgressHandler(handler func(*protocol.ProgressParams)) (token protocol.ProgressToken, unregister func()) { - token = fmt.Sprintf("tok%d", rand.Uint64()) - - // register - cli.progressHandlersMu.Lock() - if cli.progressHandlers == nil { - cli.progressHandlers = make(map[protocol.ProgressToken]func(*protocol.ProgressParams)) - } - cli.progressHandlers[token] = handler - cli.progressHandlersMu.Unlock() - - unregister = func() { - cli.progressHandlersMu.Lock() - delete(cli.progressHandlers, token) - cli.progressHandlersMu.Unlock() - } - return token, unregister -} - // cmdClient defines the protocol.Client interface behavior of the gopls CLI tool. type cmdClient struct { app *Application - progressHandlersMu sync.Mutex - progressHandlers map[protocol.ProgressToken]func(*protocol.ProgressParams) - iwlToken protocol.ProgressToken - iwlDone chan struct{} + progressMu sync.Mutex + iwlToken protocol.ProgressToken + iwlDone chan struct{} filesMu sync.Mutex // guards files map files map[protocol.DocumentURI]*cmdFile @@ -698,41 +675,33 @@ func (c *cmdClient) PublishDiagnostics(ctx context.Context, p *protocol.PublishD } func (c *cmdClient) Progress(_ context.Context, params *protocol.ProgressParams) error { - token, ok := params.Token.(string) - if !ok { + if _, ok := params.Token.(string); !ok { return fmt.Errorf("unexpected progress token: %[1]T %[1]v", params.Token) } - c.progressHandlersMu.Lock() - handler := c.progressHandlers[token] - c.progressHandlersMu.Unlock() - if handler == nil { - handler = c.defaultProgressHandler - } - handler(params) - return nil -} - -// defaultProgressHandler is the default handler of progress messages, -// used during the initialize request. -func (c *cmdClient) defaultProgressHandler(params *protocol.ProgressParams) { switch v := params.Value.(type) { case *protocol.WorkDoneProgressBegin: if v.Title == server.DiagnosticWorkTitle(server.FromInitialWorkspaceLoad) { - c.progressHandlersMu.Lock() + c.progressMu.Lock() c.iwlToken = params.Token - c.progressHandlersMu.Unlock() + c.progressMu.Unlock() + } + + case *protocol.WorkDoneProgressReport: + if c.app.Verbose { + fmt.Fprintln(os.Stderr, v.Message) } case *protocol.WorkDoneProgressEnd: - c.progressHandlersMu.Lock() + c.progressMu.Lock() iwlToken := c.iwlToken - c.progressHandlersMu.Unlock() + c.progressMu.Unlock() if params.Token == iwlToken { close(c.iwlDone) } } + return nil } func (c *cmdClient) ShowDocument(ctx context.Context, params *protocol.ShowDocumentParams) (*protocol.ShowDocumentResult, error) { diff --git a/gopls/internal/cmd/execute.go b/gopls/internal/cmd/execute.go index 96b3cf3b8..967e97ed5 100644 --- a/gopls/internal/cmd/execute.go +++ b/gopls/internal/cmd/execute.go @@ -10,12 +10,10 @@ import ( "flag" "fmt" "log" - "os" "slices" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/protocol/command" - "golang.org/x/tools/gopls/internal/server" "golang.org/x/tools/internal/tool" ) @@ -98,38 +96,11 @@ func (e *execute) Run(ctx context.Context, args ...string) error { // executeCommand executes a protocol.Command, displaying progress // messages and awaiting completion of asynchronous commands. +// +// TODO(rfindley): inline away all calls, ensuring they inline idiomatically. func (conn *connection) executeCommand(ctx context.Context, cmd *protocol.Command) (any, error) { - endStatus := make(chan string, 1) - token, unregister := conn.client.registerProgressHandler(func(params *protocol.ProgressParams) { - switch v := params.Value.(type) { - case *protocol.WorkDoneProgressReport: - fmt.Fprintln(os.Stderr, v.Message) // combined std{out,err} - - case *protocol.WorkDoneProgressEnd: - endStatus <- v.Message // = canceled | failed | completed - } - }) - defer unregister() - - res, err := conn.ExecuteCommand(ctx, &protocol.ExecuteCommandParams{ + return conn.ExecuteCommand(ctx, &protocol.ExecuteCommandParams{ Command: cmd.Command, Arguments: cmd.Arguments, - WorkDoneProgressParams: protocol.WorkDoneProgressParams{ - WorkDoneToken: token, - }, }) - if err != nil { - return nil, err - } - - // Some commands are asynchronous, so clients - // must wait for the "end" progress notification. - if command.Command(cmd.Command).IsAsync() { - status := <-endStatus - if status != server.CommandCompleted { - return nil, fmt.Errorf("command %s", status) - } - } - - return res, nil } diff --git a/gopls/internal/cmd/integration_test.go b/gopls/internal/cmd/integration_test.go index 39698f373..15888b21f 100644 --- a/gopls/internal/cmd/integration_test.go +++ b/gopls/internal/cmd/integration_test.go @@ -224,7 +224,7 @@ func TestFail(t *testing.T) { t.Fatal("fail") } } // run the passing test { - res := gopls(t, tree, "codelens", "-exec", "./a/a_test.go:3", "run test") + res := gopls(t, tree, "-v", "codelens", "-exec", "./a/a_test.go:3", "run test") res.checkExit(true) res.checkStderr(`PASS: TestPass`) // from go test res.checkStderr("Info: all tests passed") // from gopls.test diff --git a/gopls/internal/protocol/command/interface.go b/gopls/internal/protocol/command/interface.go index 0838e930c..eda608a84 100644 --- a/gopls/internal/protocol/command/interface.go +++ b/gopls/internal/protocol/command/interface.go @@ -503,7 +503,15 @@ type VulncheckArgs struct { type RunVulncheckResult struct { // Token holds the progress token for LSP workDone reporting of the vulncheck // invocation. + // + // Deprecated: previously, this was used as a signal to retrieve the result + // using gopls.fetch_vulncheck_result. Clients should ignore this field: + // gopls.vulncheck now runs synchronously, and returns a result in the Result + // field. Token protocol.ProgressToken + + // Result holds the result of running vulncheck. + Result *vulncheck.Result } // MemStatsResult holds selected fields from runtime.MemStats. diff --git a/gopls/internal/protocol/command/util.go b/gopls/internal/protocol/command/util.go index 7cd5662e3..d07cd863f 100644 --- a/gopls/internal/protocol/command/util.go +++ b/gopls/internal/protocol/command/util.go @@ -15,18 +15,6 @@ type Command string func (c Command) String() string { return string(c) } -// IsAsync reports whether the command is asynchronous: -// clients must wait for the "end" progress notification. -func (c Command) IsAsync() bool { - switch string(c) { - // TODO(adonovan): derive this list from interface.go somewhow. - // Unfortunately we can't even reference the enum from here... - case "gopls.run_tests", "gopls.run_govulncheck", "gopls.test": - return true - } - return false -} - // MarshalArgs encodes the given arguments to json.RawMessages. This function // is used to construct arguments to a protocol.Command. // diff --git a/gopls/internal/server/command.go b/gopls/internal/server/command.go index 4f6f24d86..bfc8f9c55 100644 --- a/gopls/internal/server/command.go +++ b/gopls/internal/server/command.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "io" - "log" "os" "path/filepath" "regexp" @@ -41,6 +40,7 @@ import ( "golang.org/x/tools/internal/diff" "golang.org/x/tools/internal/event" "golang.org/x/tools/internal/gocommand" + "golang.org/x/tools/internal/jsonrpc2" "golang.org/x/tools/internal/tokeninternal" "golang.org/x/tools/internal/xcontext" ) @@ -278,7 +278,7 @@ func (*commandHandler) AddTelemetryCounters(_ context.Context, args command.AddT // commandConfig configures common command set-up and execution. type commandConfig struct { requireSave bool // whether all files must be saved for the command to work - progress string // title to use for progress reporting. If empty, no progress will be reported. Mandatory for async commands. + progress string // title to use for progress reporting. If empty, no progress will be reported. forView string // view to resolve to a snapshot; incompatible with forURI forURI protocol.DocumentURI // URI to resolve to a snapshot. If unset, snapshot will be nil. } @@ -370,18 +370,6 @@ func (c *commandHandler) run(ctx context.Context, cfg commandConfig, run command return err } - if enum := command.Command(c.params.Command); enum.IsAsync() { - if cfg.progress == "" { - log.Fatalf("asynchronous command %q does not enable progress reporting", - enum) - } - go func() { - if err := runcmd(); err != nil { - showMessage(ctx, c.s.client, protocol.Error, err.Error()) - } - }() - return nil - } return runcmd() } @@ -725,6 +713,7 @@ func (c *commandHandler) RunTests(ctx context.Context, args command.RunTestsArgs requireSave: true, // go test honors overlays, but tests themselves cannot forURI: args.URI, }, func(ctx context.Context, deps commandDeps) error { + jsonrpc2.Async(ctx) // don't block RPCs behind this command, since it can take a while return c.runTests(ctx, deps.snapshot, deps.work, args.URI, args.Tests, args.Benchmarks) }) } @@ -1233,23 +1222,25 @@ func (c *commandHandler) FetchVulncheckResult(ctx context.Context, arg command.U return ret, err } +const GoVulncheckCommandTitle = "govulncheck" + func (c *commandHandler) RunGovulncheck(ctx context.Context, args command.VulncheckArgs) (command.RunVulncheckResult, error) { if args.URI == "" { return command.RunVulncheckResult{}, errors.New("VulncheckArgs is missing URI field") } - // Return the workdone token so that clients can identify when this - // vulncheck invocation is complete. - // - // Since the run function executes asynchronously, we use a channel to - // synchronize the start of the run and return the token. - tokenChan := make(chan protocol.ProgressToken, 1) + var commandResult command.RunVulncheckResult err := c.run(ctx, commandConfig{ - progress: "govulncheck", // (asynchronous) - requireSave: true, // govulncheck cannot honor overlays + progress: GoVulncheckCommandTitle, + requireSave: true, // govulncheck cannot honor overlays forURI: args.URI, }, func(ctx context.Context, deps commandDeps) error { - tokenChan <- deps.work.Token() + // For compatibility with the legacy asynchronous API, return the workdone + // token that clients used to use to identify when this vulncheck + // invocation is complete. + commandResult.Token = deps.work.Token() + + jsonrpc2.Async(ctx) // run this in parallel with other requests: vulncheck can be slow. workDoneWriter := progress.NewWorkDoneWriter(ctx, deps.work) dir := filepath.Dir(args.URI.Path()) @@ -1259,6 +1250,7 @@ func (c *commandHandler) RunGovulncheck(ctx context.Context, args command.Vulnch if err != nil { return err } + commandResult.Result = result snapshot, release, err := c.s.session.InvalidateView(ctx, deps.snapshot.View(), cache.StateChange{ Vulns: map[protocol.DocumentURI]*vulncheck.Result{args.URI: result}, @@ -1295,12 +1287,7 @@ func (c *commandHandler) RunGovulncheck(ctx context.Context, args command.Vulnch if err != nil { return command.RunVulncheckResult{}, err } - select { - case <-ctx.Done(): - return command.RunVulncheckResult{}, ctx.Err() - case token := <-tokenChan: - return command.RunVulncheckResult{Token: token}, nil - } + return commandResult, nil } // MemStats implements the MemStats command. It returns an error as a diff --git a/gopls/internal/test/integration/codelens/codelens_test.go b/gopls/internal/test/integration/codelens/codelens_test.go index 75b9fda1f..c4711135d 100644 --- a/gopls/internal/test/integration/codelens/codelens_test.go +++ b/gopls/internal/test/integration/codelens/codelens_test.go @@ -182,10 +182,10 @@ require golang.org/x/hello v1.2.3 if !found { t.Fatalf("found no command with the title %s", commandTitle) } - if _, err := env.Editor.ExecuteCommand(env.Ctx, &protocol.ExecuteCommandParams{ + if err := env.Editor.ExecuteCommand(env.Ctx, &protocol.ExecuteCommandParams{ Command: lens.Command.Command, Arguments: lens.Command.Arguments, - }); err != nil { + }, nil); err != nil { t.Fatal(err) } env.AfterChange() diff --git a/gopls/internal/test/integration/expectation.go b/gopls/internal/test/integration/expectation.go index 858daeee1..f68f1de5e 100644 --- a/gopls/internal/test/integration/expectation.go +++ b/gopls/internal/test/integration/expectation.go @@ -452,17 +452,27 @@ type WorkStatus struct { EndMsg string } -// CompletedProgress expects that workDone progress is complete for the given -// progress token. When non-nil WorkStatus is provided, it will be filled -// when the expectation is met. +// CompletedProgress expects that there is exactly one workDone progress with +// the given title, and is satisfied when that progress completes. If it is +// met, the corresponding status is written to the into argument. // -// If the token is not a progress token that the client has seen, this -// expectation is Unmeetable. -func CompletedProgress(token protocol.ProgressToken, into *WorkStatus) Expectation { +// TODO(rfindley): refactor to eliminate the redundancy with CompletedWork. +// This expectation is a vestige of older workarounds for asynchronous command +// execution. +func CompletedProgress(title string, into *WorkStatus) Expectation { check := func(s State) Verdict { - work, ok := s.work[token] - if !ok { - return Unmeetable // TODO(rfindley): refactor to allow the verdict to explain this result + var work *workProgress + for _, w := range s.work { + if w.title == title { + if work != nil { + // TODO(rfindley): refactor to allow the verdict to explain this result + return Unmeetable // multiple matches + } + work = w + } + } + if work == nil { + return Unmeetable // zero matches } if work.complete { if into != nil { @@ -473,7 +483,7 @@ func CompletedProgress(token protocol.ProgressToken, into *WorkStatus) Expectati } return Unmet } - desc := fmt.Sprintf("completed work for token %v", token) + desc := fmt.Sprintf("exactly 1 completed workDoneProgress with title %v", title) return Expectation{ Check: check, Description: desc, diff --git a/gopls/internal/test/integration/fake/editor.go b/gopls/internal/test/integration/fake/editor.go index 041891aaa..466e833f2 100644 --- a/gopls/internal/test/integration/fake/editor.go +++ b/gopls/internal/test/integration/fake/editor.go @@ -1014,10 +1014,10 @@ func (e *Editor) ApplyCodeAction(ctx context.Context, action protocol.CodeAction // Execute any commands. The specification says that commands are // executed after edits are applied. if action.Command != nil { - if _, err := e.ExecuteCommand(ctx, &protocol.ExecuteCommandParams{ + if err := e.ExecuteCommand(ctx, &protocol.ExecuteCommandParams{ Command: action.Command.Command, Arguments: action.Command.Arguments, - }); err != nil { + }, nil); err != nil { return err } } @@ -1084,6 +1084,8 @@ func (e *Editor) applyCodeActions(ctx context.Context, loc protocol.Location, di return applied, nil } +// TODO(rfindley): add missing documentation to exported methods here. + func (e *Editor) CodeActions(ctx context.Context, loc protocol.Location, diagnostics []protocol.Diagnostic, only ...protocol.CodeActionKind) ([]protocol.CodeAction, error) { if e.Server == nil { return nil, nil @@ -1098,9 +1100,35 @@ func (e *Editor) CodeActions(ctx context.Context, loc protocol.Location, diagnos return e.Server.CodeAction(ctx, params) } -func (e *Editor) ExecuteCommand(ctx context.Context, params *protocol.ExecuteCommandParams) (interface{}, error) { +func (e *Editor) ExecuteCodeLensCommand(ctx context.Context, path string, cmd command.Command, result any) error { + lenses, err := e.CodeLens(ctx, path) + if err != nil { + return err + } + var lens protocol.CodeLens + var found bool + for _, l := range lenses { + if l.Command.Command == cmd.String() { + lens = l + found = true + } + } + if !found { + return fmt.Errorf("found no command with the ID %s", cmd) + } + return e.ExecuteCommand(ctx, &protocol.ExecuteCommandParams{ + Command: lens.Command.Command, + Arguments: lens.Command.Arguments, + }, result) +} + +// ExecuteCommand makes a workspace/executeCommand request to the connected LSP +// server, if any. +// +// Result contains a pointer to a variable to be populated by json.Unmarshal. +func (e *Editor) ExecuteCommand(ctx context.Context, params *protocol.ExecuteCommandParams, result any) error { if e.Server == nil { - return nil, nil + return nil } var match bool if e.serverCapabilities.ExecuteCommandProvider != nil { @@ -1113,18 +1141,37 @@ func (e *Editor) ExecuteCommand(ctx context.Context, params *protocol.ExecuteCom } } if !match { - return nil, fmt.Errorf("unsupported command %q", params.Command) + return fmt.Errorf("unsupported command %q", params.Command) } - result, err := e.Server.ExecuteCommand(ctx, params) + response, err := e.Server.ExecuteCommand(ctx, params) if err != nil { - return nil, err + return err } // Some commands use the go command, which writes directly to disk. // For convenience, check for those changes. if err := e.sandbox.Workdir.CheckForFileChanges(ctx); err != nil { - return nil, fmt.Errorf("checking for file changes: %v", err) + return fmt.Errorf("checking for file changes: %v", err) } - return result, nil + if result != nil { + // ExecuteCommand already unmarshalled the response without knowing + // its schema, using the generic map[string]any representation. + // Encode and decode again, this time into a typed variable. + // + // This could be improved by generating a jsonrpc2 command client from the + // command.Interface, but that should only be done if we're consolidating + // this part of the tsprotocol generation. + // + // TODO(rfindley): we could also improve this by having ExecuteCommand return + // a json.RawMessage, similar to what we do with arguments. + data, err := json.Marshal(response) + if err != nil { + return bug.Errorf("marshalling response: %v", err) + } + if err := json.Unmarshal(data, result); err != nil { + return fmt.Errorf("unmarshalling response: %v", err) + } + } + return nil } // FormatBuffer gofmts a Go file. @@ -1183,7 +1230,7 @@ func (e *Editor) RunGenerate(ctx context.Context, dir string) error { Command: cmd.Command, Arguments: cmd.Arguments, } - if _, err := e.ExecuteCommand(ctx, params); err != nil { + if err := e.ExecuteCommand(ctx, params, nil); err != nil { return fmt.Errorf("running generate: %v", err) } // Unfortunately we can't simply poll the workdir for file changes here, diff --git a/gopls/internal/test/integration/fake/workdir.go b/gopls/internal/test/integration/fake/workdir.go index be3cb3bcf..54fabb358 100644 --- a/gopls/internal/test/integration/fake/workdir.go +++ b/gopls/internal/test/integration/fake/workdir.go @@ -73,7 +73,7 @@ func writeFileData(path string, content []byte, rel RelativeTo) error { // isWindowsErrLockViolation reports whether err is ERROR_LOCK_VIOLATION // on Windows. -var isWindowsErrLockViolation = func(err error) bool { return false } +var isWindowsErrLockViolation = func(error) bool { return false } // Workdir is a temporary working directory for tests. It exposes file // operations in terms of relative paths, and fakes file watching by triggering diff --git a/gopls/internal/test/integration/misc/vuln_test.go b/gopls/internal/test/integration/misc/vuln_test.go index 7be02b3ce..05cdbe859 100644 --- a/gopls/internal/test/integration/misc/vuln_test.go +++ b/gopls/internal/test/integration/misc/vuln_test.go @@ -17,6 +17,7 @@ import ( "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/protocol" "golang.org/x/tools/gopls/internal/protocol/command" + "golang.org/x/tools/gopls/internal/server" "golang.org/x/tools/gopls/internal/test/compare" . "golang.org/x/tools/gopls/internal/test/integration" "golang.org/x/tools/gopls/internal/vulncheck" @@ -41,10 +42,11 @@ package foo Arguments: cmd.Arguments, } - response, err := env.Editor.ExecuteCommand(env.Ctx, params) + var result any + err := env.Editor.ExecuteCommand(env.Ctx, params, &result) // We want an error! if err == nil { - t.Errorf("got success, want invalid file URL error: %v", response) + t.Errorf("got success, want invalid file URL error. Result: %v", result) } }) } @@ -72,13 +74,16 @@ func F() { // build error incomplete ).Run(t, files, func(t *testing.T, env *Env) { env.OpenFile("go.mod") var result command.RunVulncheckResult - env.ExecuteCodeLensCommand("go.mod", command.RunGovulncheck, &result) + err := env.Editor.ExecuteCodeLensCommand(env.Ctx, "go.mod", command.RunGovulncheck, &result) + if err == nil { + t.Fatalf("govulncheck succeeded unexpectedly: %v", result) + } var ws WorkStatus env.Await( - CompletedProgress(result.Token, &ws), + CompletedProgress(server.GoVulncheckCommandTitle, &ws), ) wantEndMsg, wantMsgPart := "failed", "There are errors with the provided package patterns:" - if ws.EndMsg != "failed" || !strings.Contains(ws.Msg, wantMsgPart) { + if ws.EndMsg != "failed" || !strings.Contains(ws.Msg, wantMsgPart) || !strings.Contains(err.Error(), wantMsgPart) { t.Errorf("work status = %+v, want {EndMessage: %q, Message: %q}", ws, wantEndMsg, wantMsgPart) } }) @@ -203,14 +208,16 @@ func main() { env.ExecuteCodeLensCommand("go.mod", command.RunGovulncheck, &result) env.OnceMet( - CompletedProgress(result.Token, nil), + CompletedProgress(server.GoVulncheckCommandTitle, nil), ShownMessage("Found GOSTDLIB"), NoDiagnostics(ForFile("go.mod")), ) - testFetchVulncheckResult(t, env, map[string]fetchVulncheckResult{ - "go.mod": {IDs: []string{"GOSTDLIB"}, Mode: vulncheck.ModeGovulncheck}}) + testFetchVulncheckResult(t, env, "go.mod", result.Result, map[string]fetchVulncheckResult{ + "go.mod": {IDs: []string{"GOSTDLIB"}, Mode: vulncheck.ModeGovulncheck}, + }) }) } + func TestFetchVulncheckResultStd(t *testing.T) { const files = ` -- go.mod -- @@ -252,7 +259,7 @@ func main() { NoDiagnostics(ForFile("go.mod")), // we don't publish diagnostics for standard library vulnerability yet. ) - testFetchVulncheckResult(t, env, map[string]fetchVulncheckResult{ + testFetchVulncheckResult(t, env, "", nil, map[string]fetchVulncheckResult{ "go.mod": { IDs: []string{"GOSTDLIB"}, Mode: vulncheck.ModeImports, @@ -261,12 +268,28 @@ func main() { }) } +// fetchVulncheckResult summarizes a vulncheck result for a single file. type fetchVulncheckResult struct { IDs []string Mode vulncheck.AnalysisMode } -func testFetchVulncheckResult(t *testing.T, env *Env, want map[string]fetchVulncheckResult) { +// testFetchVulncheckResult checks that calling gopls.fetch_vulncheck_result +// returns the expected summarized results contained in the want argument. +// +// If fromRun is non-nil, is is the result of running running vulncheck for +// runPath, and testFetchVulncheckResult also checks that the fetched result +// for runPath matches fromRun. +// +// This awkward factoring is an artifact of a transition from fetching +// vulncheck results asynchronously, to allowing the command to run +// asynchronously, yet returning the result synchronously from the client's +// perspective. +// +// TODO(rfindley): once VS Code no longer depends on fetching results +// asynchronously, we can remove gopls.fetch_vulncheck_result, and simplify or +// remove this helper. +func testFetchVulncheckResult(t *testing.T, env *Env, runPath string, fromRun *vulncheck.Result, want map[string]fetchVulncheckResult) { t.Helper() var result map[protocol.DocumentURI]*vulncheck.Result @@ -281,8 +304,7 @@ func testFetchVulncheckResult(t *testing.T, env *Env, want map[string]fetchVulnc for _, v := range want { sort.Strings(v.IDs) } - got := map[string]fetchVulncheckResult{} - for k, r := range result { + summarize := func(r *vulncheck.Result) fetchVulncheckResult { osv := map[string]bool{} for _, v := range r.Findings { osv[v.OSV] = true @@ -292,14 +314,23 @@ func testFetchVulncheckResult(t *testing.T, env *Env, want map[string]fetchVulnc ids = append(ids, id) } sort.Strings(ids) - modfile := env.Sandbox.Workdir.RelPath(k.Path()) - got[modfile] = fetchVulncheckResult{ + return fetchVulncheckResult{ IDs: ids, Mode: r.Mode, } } - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("fetch vulnchheck result = got %v, want %v: diff %v", got, want, diff) + got := map[string]fetchVulncheckResult{} + for k, r := range result { + modfile := env.Sandbox.Workdir.RelPath(k.Path()) + got[modfile] = summarize(r) + } + if fromRun != nil { + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("fetch vulncheck result = got %v, want %v: diff %v", got, want, diff) + } + if diff := cmp.Diff(summarize(fromRun), got[runPath]); diff != "" { + t.Errorf("fetched vulncheck result differs from returned (-returned, +fetched):\n%s", diff) + } } } @@ -463,7 +494,7 @@ func TestRunVulncheckPackageDiagnostics(t *testing.T) { ReadDiagnostics("go.mod", gotDiagnostics), ) - testFetchVulncheckResult(t, env, map[string]fetchVulncheckResult{ + testFetchVulncheckResult(t, env, "", nil, map[string]fetchVulncheckResult{ "go.mod": { IDs: []string{"GO-2022-01", "GO-2022-02", "GO-2022-03"}, Mode: vulncheck.ModeImports, @@ -531,7 +562,7 @@ func TestRunVulncheckPackageDiagnostics(t *testing.T) { if len(gotDiagnostics.Diagnostics) > 0 { t.Errorf("Unexpected diagnostics: %v", stringify(gotDiagnostics)) } - testFetchVulncheckResult(t, env, map[string]fetchVulncheckResult{}) + testFetchVulncheckResult(t, env, "", nil, map[string]fetchVulncheckResult{}) } for _, tc := range []struct { @@ -561,7 +592,7 @@ func TestRunVulncheckPackageDiagnostics(t *testing.T) { env.ExecuteCodeLensCommand("go.mod", command.RunGovulncheck, &result) gotDiagnostics := &protocol.PublishDiagnosticsParams{} env.OnceMet( - CompletedProgress(result.Token, nil), + CompletedProgress(server.GoVulncheckCommandTitle, nil), ShownMessage("Found"), ) env.OnceMet( @@ -609,7 +640,7 @@ func TestRunGovulncheck_Expiry(t *testing.T) { var result command.RunVulncheckResult env.ExecuteCodeLensCommand("go.mod", command.RunGovulncheck, &result) env.OnceMet( - CompletedProgress(result.Token, nil), + CompletedProgress(server.GoVulncheckCommandTitle, nil), ShownMessage("Found"), ) // Sleep long enough for the results to expire. @@ -640,7 +671,7 @@ func TestRunVulncheckWarning(t *testing.T) { env.ExecuteCodeLensCommand("go.mod", command.RunGovulncheck, &result) gotDiagnostics := &protocol.PublishDiagnosticsParams{} env.OnceMet( - CompletedProgress(result.Token, nil), + CompletedProgress(server.GoVulncheckCommandTitle, nil), ShownMessage("Found"), ) // Vulncheck diagnostics asynchronous to the vulncheck command. @@ -649,7 +680,7 @@ func TestRunVulncheckWarning(t *testing.T) { ReadDiagnostics("go.mod", gotDiagnostics), ) - testFetchVulncheckResult(t, env, map[string]fetchVulncheckResult{ + testFetchVulncheckResult(t, env, "go.mod", result.Result, map[string]fetchVulncheckResult{ // All vulnerabilities (symbol-level, import-level, module-level) are reported. "go.mod": {IDs: []string{"GO-2022-01", "GO-2022-02", "GO-2022-03", "GO-2022-04"}, Mode: vulncheck.ModeGovulncheck}, }) @@ -795,7 +826,7 @@ func TestGovulncheckInfo(t *testing.T) { env.ExecuteCodeLensCommand("go.mod", command.RunGovulncheck, &result) gotDiagnostics := &protocol.PublishDiagnosticsParams{} env.OnceMet( - CompletedProgress(result.Token, nil), + CompletedProgress(server.GoVulncheckCommandTitle, nil), ShownMessage("No vulnerabilities found"), // only count affecting vulnerabilities. ) @@ -805,7 +836,9 @@ func TestGovulncheckInfo(t *testing.T) { ReadDiagnostics("go.mod", gotDiagnostics), ) - testFetchVulncheckResult(t, env, map[string]fetchVulncheckResult{"go.mod": {IDs: []string{"GO-2022-02", "GO-2022-04"}, Mode: vulncheck.ModeGovulncheck}}) + testFetchVulncheckResult(t, env, "go.mod", result.Result, map[string]fetchVulncheckResult{ + "go.mod": {IDs: []string{"GO-2022-02", "GO-2022-04"}, Mode: vulncheck.ModeGovulncheck}, + }) // wantDiagnostics maps a module path in the require // section of a go.mod to diagnostics that will be returned // when running vulncheck. diff --git a/gopls/internal/test/integration/wrappers.go b/gopls/internal/test/integration/wrappers.go index 4e59f1a5b..68d23ddd2 100644 --- a/gopls/internal/test/integration/wrappers.go +++ b/gopls/internal/test/integration/wrappers.go @@ -5,7 +5,6 @@ package integration import ( - "encoding/json" "errors" "os" "path" @@ -387,46 +386,22 @@ func (e *Env) CodeLens(path string) []protocol.CodeLens { // ExecuteCodeLensCommand executes the command for the code lens matching the // given command name. -func (e *Env) ExecuteCodeLensCommand(path string, cmd command.Command, result interface{}) { +// +// result is a pointer to a variable to be populated by json.Unmarshal. +func (e *Env) ExecuteCodeLensCommand(path string, cmd command.Command, result any) { e.T.Helper() - lenses := e.CodeLens(path) - var lens protocol.CodeLens - var found bool - for _, l := range lenses { - if l.Command.Command == cmd.String() { - lens = l - found = true - } + if err := e.Editor.ExecuteCodeLensCommand(e.Ctx, path, cmd, result); err != nil { + e.T.Fatal(err) } - if !found { - e.T.Fatalf("found no command with the ID %s", cmd) - } - e.ExecuteCommand(&protocol.ExecuteCommandParams{ - Command: lens.Command.Command, - Arguments: lens.Command.Arguments, - }, result) } -func (e *Env) ExecuteCommand(params *protocol.ExecuteCommandParams, result interface{}) { +// ExecuteCommand executes the requested command in the editor, calling t.Fatal +// on any error. +// +// result is a pointer to a variable to be populated by json.Unmarshal. +func (e *Env) ExecuteCommand(params *protocol.ExecuteCommandParams, result any) { e.T.Helper() - response, err := e.Editor.ExecuteCommand(e.Ctx, params) - if err != nil { - e.T.Fatal(err) - } - if result == nil { - return - } - // Hack: The result of an executeCommand request will be unmarshaled into - // maps. Re-marshal and unmarshal into the type we expect. - // - // This could be improved by generating a jsonrpc2 command client from the - // command.Interface, but that should only be done if we're consolidating - // this part of the tsprotocol generation. - data, err := json.Marshal(response) - if err != nil { - e.T.Fatal(err) - } - if err := json.Unmarshal(data, result); err != nil { + if err := e.Editor.ExecuteCommand(e.Ctx, params, result); err != nil { e.T.Fatal(err) } } diff --git a/gopls/internal/vulncheck/scan/command.go b/gopls/internal/vulncheck/scan/command.go index 4ef005010..1b703a720 100644 --- a/gopls/internal/vulncheck/scan/command.go +++ b/gopls/internal/vulncheck/scan/command.go @@ -91,7 +91,7 @@ func RunGovulncheck(ctx context.Context, pattern string, snapshot *cache.Snapsho if stderr.Len() > 0 { log.Write(stderr.Bytes()) } - return nil, fmt.Errorf("failed to read govulncheck output: %v", err) + return nil, fmt.Errorf("failed to read govulncheck output: %v: stderr:\n%s", err, stderr) } findings := handler.findings // sort so the findings in the result is deterministic. diff --git a/internal/jsonrpc2/handler.go b/internal/jsonrpc2/handler.go index 418bd6804..27cb10892 100644 --- a/internal/jsonrpc2/handler.go +++ b/internal/jsonrpc2/handler.go @@ -27,8 +27,8 @@ func MethodNotFound(ctx context.Context, reply Replier, req Request) error { return reply(ctx, nil, fmt.Errorf("%w: %q", ErrMethodNotFound, req.Method())) } -// MustReplyHandler creates a Handler that panics if the wrapped handler does -// not call Reply for every request that it is passed. +// MustReplyHandler is a middleware that creates a Handler that panics if the +// wrapped handler does not call Reply for every request that it is passed. func MustReplyHandler(handler Handler) Handler { return func(ctx context.Context, reply Replier, req Request) error { called := false @@ -78,8 +78,8 @@ func CancelHandler(handler Handler) (Handler, func(id ID)) { } } -// AsyncHandler returns a handler that processes each request goes in its own -// goroutine. +// AsyncHandler is a middleware that returns a handler that processes each +// request goes in its own goroutine. // The handler returns immediately, without the request being processed. // Each request then waits for the previous request to finish before it starts. // This allows the stream to unblock at the cost of unbounded goroutines @@ -90,13 +90,14 @@ func AsyncHandler(handler Handler) Handler { return func(ctx context.Context, reply Replier, req Request) error { waitForPrevious := nextRequest nextRequest = make(chan struct{}) - unlockNext := nextRequest + releaser := &releaser{ch: nextRequest} innerReply := reply reply = func(ctx context.Context, result interface{}, err error) error { - close(unlockNext) + releaser.release(true) return innerReply(ctx, result, err) } _, queueDone := event.Start(ctx, "queued") + ctx = context.WithValue(ctx, asyncKey, releaser) go func() { <-waitForPrevious queueDone() @@ -107,3 +108,46 @@ func AsyncHandler(handler Handler) Handler { return nil } } + +// Async, when used with the [AsyncHandler] middleware, indicates that the +// current jsonrpc2 request may be handled asynchronously to subsequent +// requests. +// +// When not used with an AsyncHandler, Async is a no-op. +// +// Async must be called at most once on each request's context (and its +// descendants). +func Async(ctx context.Context) { + if r, ok := ctx.Value(asyncKey).(*releaser); ok { + r.release(false) + } +} + +type asyncKeyType struct{} + +var asyncKey = asyncKeyType{} + +// A releaser implements concurrency safe 'releasing' of async requests. (A +// request is released when it is allowed to run concurrent with other +// requests, via a call to [Async].) +type releaser struct { + mu sync.Mutex + ch chan struct{} + released bool +} + +// release closes the associated channel. If soft is set, multiple calls to +// release are allowed. +func (r *releaser) release(soft bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.released { + if !soft { + panic("jsonrpc2.Async called multiple times") + } + } else { + close(r.ch) + r.released = true + } +}