diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index 06065b406..64227c71b 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -10,6 +10,7 @@ import ( "go/ast" "go/format" "go/parser" + "go/printer" "go/token" "go/types" "slices" @@ -449,7 +450,8 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte return nil, nil, err } selection := src[startOffset:endOffset] - extractedBlock, err := parseBlockStmt(fset, selection) + + extractedBlock, extractedComments, err := parseStmts(fset, selection) if err != nil { return nil, nil, err } @@ -570,40 +572,16 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte if canDefine { sym = token.DEFINE } - var name, funName string + var funName string if isMethod { - name = "newMethod" // TODO(suzmue): generate a name that does not conflict for "newMethod". - funName = name + funName = "newMethod" } else { - name = "newFunction" - funName, _ = generateAvailableName(start, path, pkg, info, name, 0) + funName, _ = generateAvailableName(start, path, pkg, info, "newFunction", 0) } extractedFunCall := generateFuncCall(hasNonNestedReturn, hasReturnValues, params, append(returns, getNames(retVars)...), funName, sym, receiverName) - // Build the extracted function. - newFunc := &ast.FuncDecl{ - Name: ast.NewIdent(funName), - Type: &ast.FuncType{ - Params: &ast.FieldList{List: paramTypes}, - Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)}, - }, - Body: extractedBlock, - } - if isMethod { - var names []*ast.Ident - if receiverUsed { - names = append(names, ast.NewIdent(receiverName)) - } - newFunc.Recv = &ast.FieldList{ - List: []*ast.Field{{ - Names: names, - Type: receiver.Type, - }}, - } - } - // Create variable declarations for any identifiers that need to be initialized prior to // calling the extracted function. We do not manually initialize variables if every return // value is uninitialized. We can use := to initialize the variables in this situation. @@ -624,17 +602,49 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte return nil, nil, err } } + + // Build the extracted function. We format the function declaration and body + // separately, so that comments are printed relative to the extracted + // BlockStmt. + // + // In other words, extractedBlock and extractedComments were parsed from a + // synthetic function declaration of the form func _() { ... }. If we now + // print the real function declaration, the length of the signature will have + // grown, causing some comment positions to be computed as inside the + // signature itself. + newFunc := &ast.FuncDecl{ + Name: ast.NewIdent(funName), + Type: &ast.FuncType{ + Params: &ast.FieldList{List: paramTypes}, + Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)}, + }, + // Body handled separately -- see above. + } + if isMethod { + var names []*ast.Ident + if receiverUsed { + names = append(names, ast.NewIdent(receiverName)) + } + newFunc.Recv = &ast.FieldList{ + List: []*ast.Field{{ + Names: names, + Type: receiver.Type, + }}, + } + } if err := format.Node(&newFuncBuf, fset, newFunc); err != nil { return nil, nil, err } - // Find all the comments within the range and print them to be put somewhere. - // TODO(suzmue): print these in the extracted function at the correct place. - for _, cg := range file.Comments { - if cg.Pos().IsValid() && cg.Pos() < end && cg.Pos() >= start { - for _, c := range cg.List { - fmt.Fprintln(&commentBuf, c.Text) - } - } + // Write a space between the end of the function signature and opening '{'. + if err := newFuncBuf.WriteByte(' '); err != nil { + return nil, nil, err + } + commentedNode := &printer.CommentedNode{ + Node: extractedBlock, + Comments: extractedComments, + } + if err := format.Node(&newFuncBuf, fset, commentedNode); err != nil { + return nil, nil, err } // We're going to replace the whole enclosing function, @@ -1187,25 +1197,25 @@ func varOverridden(info *types.Info, firstUse *ast.Ident, obj types.Object, isFr return isOverriden } -// parseBlockStmt generates an AST file from the given text. We then return the portion of the -// file that represents the text. -func parseBlockStmt(fset *token.FileSet, src []byte) (*ast.BlockStmt, error) { +// parseStmts parses the specified source (a list of statements) and +// returns them as a BlockStmt along with any associated comments. +func parseStmts(fset *token.FileSet, src []byte) (*ast.BlockStmt, []*ast.CommentGroup, error) { text := "package main\nfunc _() { " + string(src) + " }" - extract, err := parser.ParseFile(fset, "", text, parser.SkipObjectResolution) + file, err := parser.ParseFile(fset, "", text, parser.ParseComments|parser.SkipObjectResolution) if err != nil { - return nil, err + return nil, nil, err } - if len(extract.Decls) == 0 { - return nil, fmt.Errorf("parsed file does not contain any declarations") + if len(file.Decls) != 1 { + return nil, nil, fmt.Errorf("got %d declarations, want 1", len(file.Decls)) } - decl, ok := extract.Decls[0].(*ast.FuncDecl) + decl, ok := file.Decls[0].(*ast.FuncDecl) if !ok { - return nil, fmt.Errorf("parsed file does not contain expected function declaration") + return nil, nil, bug.Errorf("parsed file does not contain expected function declaration") } if decl.Body == nil { - return nil, fmt.Errorf("extracted function has no body") + return nil, nil, bug.Errorf("extracted function has no body") } - return decl.Body, nil + return decl.Body, file.Comments, nil } // generateReturnInfo generates the information we need to adjust the return statements and diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_method.txt b/gopls/internal/test/marker/testdata/codeaction/extract_method.txt index 44517e259..49388f5bc 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract_method.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract_method.txt @@ -237,20 +237,14 @@ func (b *B) LongListWithT(ctx context.Context, t *testing.T) (int, error) { +} + -- @contextFuncB/context.go -- -@@ -33 +33,6 @@ -- sum := b.x + b.y //@loc(B_AddPWithB, re`(?s:^.*?Err\(\))`) -+ //@loc(B_AddPWithB, re`(?s:^.*?Err\(\))`) +@@ -33 +33,4 @@ + return newFunction(ctx, tB, b) +} + +func newFunction(ctx context.Context, tB *testing.B, b *B) (int, error) { -+ sum := b.x + b.y -- @contextFuncT/context.go -- -@@ -42 +42,6 @@ -- p4 := p1 + p2 //@loc(B_LongListWithT, re`(?s:^.*?Err\(\))`) -+ //@loc(B_LongListWithT, re`(?s:^.*?Err\(\))`) +@@ -42 +42,4 @@ + return newFunction(ctx, t, p1, p2, p3) +} + +func newFunction(ctx context.Context, t *testing.T, p1 int, p2 int, p3 int) (int, error) { -+ p4 := p1 + p2 diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt index f8081b41c..f84eeae7b 100644 --- a/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt @@ -17,12 +17,11 @@ func _() { //@codeaction("{", "refactor.extract.function", end=closeBracket, res package extract func _() { //@codeaction("{", "refactor.extract.function", end=closeBracket, result=outer) - //@codeaction("a", "refactor.extract.function", end=end, result=inner) newFunction() //@loc(end, "4") } func newFunction() { - a := 1 + a := 1 //@codeaction("a", "refactor.extract.function", end=end, result=inner) _ = a + 4 } //@loc(closeBracket, "}") @@ -30,12 +29,11 @@ func newFunction() { package extract func _() { //@codeaction("{", "refactor.extract.function", end=closeBracket, result=outer) - //@codeaction("a", "refactor.extract.function", end=end, result=inner) newFunction() //@loc(end, "4") } func newFunction() { - a := 1 + a := 1 //@codeaction("a", "refactor.extract.function", end=end, result=inner) _ = a + 4 } //@loc(closeBracket, "}") @@ -55,7 +53,6 @@ package extract func _() bool { x := 1 - //@codeaction("if", "refactor.extract.function", end=ifend, result=return) shouldReturn, b := newFunction(x) if shouldReturn { return b @@ -64,7 +61,7 @@ func _() bool { } func newFunction(x int) (bool, bool) { - if x == 0 { + if x == 0 { //@codeaction("if", "refactor.extract.function", end=ifend, result=return) return true, true } return false, false @@ -85,12 +82,11 @@ func _() bool { package extract func _() bool { - //@codeaction("x", "refactor.extract.function", end=rnnEnd, result=rnn) return newFunction() //@loc(rnnEnd, "false") } func newFunction() bool { - x := 1 + x := 1 //@codeaction("x", "refactor.extract.function", end=rnnEnd, result=rnn) if x == 0 { return true } @@ -123,7 +119,6 @@ import "fmt" func _() (int, string, error) { x := 1 y := "hello" - //@codeaction("z", "refactor.extract.function", end=rcEnd, result=rc) z, shouldReturn, i, s, err := newFunction(y, x) if shouldReturn { return i, s, err @@ -132,7 +127,7 @@ func _() (int, string, error) { } func newFunction(y string, x int) (string, bool, int, string, error) { - z := "bye" + z := "bye" //@codeaction("z", "refactor.extract.function", end=rcEnd, result=rc) if y == z { return "", true, x, y, fmt.Errorf("same") } else if false { @@ -168,12 +163,11 @@ import "fmt" func _() (int, string, error) { x := 1 y := "hello" - //@codeaction("z", "refactor.extract.function", end=rcnnEnd, result=rcnn) return newFunction(y, x) //@loc(rcnnEnd, "nil") } func newFunction(y string, x int) (int, string, error) { - z := "bye" + z := "bye" //@codeaction("z", "refactor.extract.function", end=rcnnEnd, result=rcnn) if y == z { return x, y, fmt.Errorf("same") } else if false { @@ -204,7 +198,6 @@ import "go/ast" func _() { ast.Inspect(ast.NewIdent("a"), func(n ast.Node) bool { - //@codeaction("if", "refactor.extract.function", end=rflEnd, result=rfl) shouldReturn, b := newFunction(n) if shouldReturn { return b @@ -214,7 +207,7 @@ func _() { } func newFunction(n ast.Node) (bool, bool) { - if n == nil { + if n == nil { //@codeaction("if", "refactor.extract.function", end=rflEnd, result=rfl) return true, true } return false, false @@ -241,13 +234,12 @@ import "go/ast" func _() { ast.Inspect(ast.NewIdent("a"), func(n ast.Node) bool { - //@codeaction("if", "refactor.extract.function", end=rflnnEnd, result=rflnn) return newFunction(n) //@loc(rflnnEnd, "false") }) } func newFunction(n ast.Node) bool { - if n == nil { + if n == nil { //@codeaction("if", "refactor.extract.function", end=rflnnEnd, result=rflnn) return true } return false @@ -271,7 +263,6 @@ package extract func _() string { x := 1 - //@codeaction("if", "refactor.extract.function", end=riEnd, result=ri) shouldReturn, s := newFunction(x) if shouldReturn { return s @@ -281,7 +272,7 @@ func _() string { } func newFunction(x int) (bool, string) { - if x == 0 { + if x == 0 { //@codeaction("if", "refactor.extract.function", end=riEnd, result=ri) x = 3 return true, "a" } @@ -306,12 +297,11 @@ package extract func _() string { x := 1 - //@codeaction("if", "refactor.extract.function", end=rinnEnd, result=rinn) return newFunction(x) //@loc(rinnEnd, "\"b\"") } func newFunction(x int) string { - if x == 0 { + if x == 0 { //@codeaction("if", "refactor.extract.function", end=rinnEnd, result=rinn) x = 3 return "a" } @@ -336,7 +326,6 @@ package extract func _() { a := 1 - //@codeaction("a", "refactor.extract.function", end=araend, result=ara) a = newFunction(a) //@loc(araend, "2") b := a * 2 //@codeaction("b", "refactor.extract.function", end=arbend, result=arb) @@ -344,7 +333,7 @@ func _() { } func newFunction(a int) int { - a = 5 + a = 5 //@codeaction("a", "refactor.extract.function", end=araend, result=ara) a = a + 2 return a } @@ -357,12 +346,11 @@ func _() { a = 5 //@codeaction("a", "refactor.extract.function", end=araend, result=ara) a = a + 2 //@loc(araend, "2") - //@codeaction("b", "refactor.extract.function", end=arbend, result=arb) newFunction(a) //@loc(arbend, "4") } func newFunction(a int) { - b := a * 2 + b := a * 2 //@codeaction("b", "refactor.extract.function", end=arbend, result=arb) _ = b + 4 } @@ -412,13 +400,12 @@ package extract func _() { var a []int - //@codeaction("a", "refactor.extract.function", end=siEnd, result=si) a, b := newFunction(a) //@loc(siEnd, "4") a = append(a, b) } func newFunction(a []int) ([]int, int) { - a = append(a, 2) + a = append(a, 2) //@codeaction("a", "refactor.extract.function", end=siEnd, result=si) b := 4 return a, b } @@ -441,13 +428,12 @@ package extract func _() { var b []int var a int - //@codeaction("a", "refactor.extract.function", end=srEnd, result=sr) b = newFunction(a, b) //@loc(srEnd, ")") b[0] = 1 } func newFunction(a int, b []int) []int { - a = 2 + a = 2 //@codeaction("a", "refactor.extract.function", end=srEnd, result=sr) b = []int{} b = append(b, a) return b @@ -472,7 +458,6 @@ package extract func _() { var b []int - //@codeaction("a", "refactor.extract.function", end=upEnd, result=up) a, b := newFunction(b) //@loc(upEnd, ")") b[0] = 1 if a == 2 { @@ -481,7 +466,7 @@ func _() { } func newFunction(b []int) (int, []int) { - a := 2 + a := 2 //@codeaction("a", "refactor.extract.function", end=upEnd, result=up) b = []int{} b = append(b, a) return a, b @@ -503,9 +488,6 @@ func _() { package extract func _() { - /* comment in the middle of a line */ - //@codeaction("a", "refactor.extract.function", end=commentEnd, result=comment1) - // Comment on its own line //@codeaction("Comment", "refactor.extract.function", end=commentEnd, result=comment2) newFunction() //@loc(commentEnd, "4"),codeaction("_", "refactor.extract.function", end=lastComment, result=comment3) // Comment right after 3 + 4 @@ -513,8 +495,8 @@ func _() { } func newFunction() { - a := 1 - + a := /* comment in the middle of a line */ 1 //@codeaction("a", "refactor.extract.function", end=commentEnd, result=comment1) + // Comment on its own line //@codeaction("Comment", "refactor.extract.function", end=commentEnd, result=comment2) _ = a + 4 } @@ -602,12 +584,11 @@ import "slices" // issue go#64821 func _() { - //@codeaction("var", "refactor.extract.function", end=anonEnd, result=anon1) newFunction() //@loc(anonEnd, ")") } func newFunction() { - var s []string + var s []string //@codeaction("var", "refactor.extract.function", end=anonEnd, result=anon1) slices.SortFunc(s, func(a, b string) int { return cmp.Compare(a, b) }) diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue44813.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue44813.txt index 7ba89f4df..c1302b1bf 100644 --- a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue44813.txt +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue44813.txt @@ -26,13 +26,12 @@ package extract import "fmt" func main() { - //@codeaction("x", "refactor.extract.function", end=end, result=ext) x := newFunction() //@loc(end, "}") fmt.Printf("%x\n", x) } func newFunction() []rune { - x := []rune{} + x := []rune{} //@codeaction("x", "refactor.extract.function", end=end, result=ext) s := "HELLO" for _, c := range s { x = append(x, c) diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue50851.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue50851.txt new file mode 100644 index 000000000..b085559cf --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue50851.txt @@ -0,0 +1,35 @@ +This test checks that function extraction moves comments along with the +extracted code. + +-- main.go -- +package main + +type F struct{} + +func (f *F) func1() { + println("a") + + println("b") //@ codeaction("print", "refactor.extract.function", end=end, result=result) + // This line prints the third letter of the alphabet. + println("c") //@loc(end, ")") + + println("d") +} +-- @result/main.go -- +package main + +type F struct{} + +func (f *F) func1() { + println("a") + + newFunction() //@loc(end, ")") + + println("d") +} + +func newFunction() { + println("b") //@ codeaction("print", "refactor.extract.function", end=end, result=result) + // This line prints the third letter of the alphabet. + println("c") +} diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt index 4444edd9d..30db2fb3e 100644 --- a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt @@ -29,7 +29,6 @@ import ( ) func F() error { - //@codeaction("a", "refactor.extract.function", end=endF, result=F) a, b, shouldReturn, err := newFunction() if shouldReturn { return err @@ -39,7 +38,7 @@ func F() error { } func newFunction() ([]byte, []byte, bool, error) { - a, err := json.Marshal(0) + a, err := json.Marshal(0) //@codeaction("a", "refactor.extract.function", end=endF, result=F) if err != nil { return nil, nil, true, fmt.Errorf("1: %w", err) } @@ -78,7 +77,6 @@ import ( ) func G() (x, y int) { - //@codeaction("v", "refactor.extract.function", end=endG, result=G) v, shouldReturn, x1, y1 := newFunction() if shouldReturn { return x1, y1 @@ -88,7 +86,7 @@ func G() (x, y int) { } func newFunction() (int, bool, int, int) { - v := rand.Int() + v := rand.Int() //@codeaction("v", "refactor.extract.function", end=endG, result=G) if v < 0 { return 0, true, 1, 2 }