gopls/internal/golang/extract: preserve comments in extracted block

Use printer.CommentedNode to preserve comments in function and method
extraction.

Fixes golang/go#50851

Change-Id: I7d8aa2683c980e613592f64646f8077952ea61be
Reviewed-on: https://go-review.googlesource.com/c/tools/+/629376
Reviewed-by: Alan Donovan <adonovan@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
Rob Findley 2024-11-19 14:54:55 +00:00 коммит произвёл Robert Findley
Родитель 8c3ba8c103
Коммит 9dff42e52e
6 изменённых файлов: 115 добавлений и 98 удалений

Просмотреть файл

@ -10,6 +10,7 @@ import (
"go/ast"
"go/format"
"go/parser"
"go/printer"
"go/token"
"go/types"
"slices"
@ -449,7 +450,8 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte
return nil, nil, err
}
selection := src[startOffset:endOffset]
extractedBlock, err := parseBlockStmt(fset, selection)
extractedBlock, extractedComments, err := parseStmts(fset, selection)
if err != nil {
return nil, nil, err
}
@ -570,40 +572,16 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte
if canDefine {
sym = token.DEFINE
}
var name, funName string
var funName string
if isMethod {
name = "newMethod"
// TODO(suzmue): generate a name that does not conflict for "newMethod".
funName = name
funName = "newMethod"
} else {
name = "newFunction"
funName, _ = generateAvailableName(start, path, pkg, info, name, 0)
funName, _ = generateAvailableName(start, path, pkg, info, "newFunction", 0)
}
extractedFunCall := generateFuncCall(hasNonNestedReturn, hasReturnValues, params,
append(returns, getNames(retVars)...), funName, sym, receiverName)
// Build the extracted function.
newFunc := &ast.FuncDecl{
Name: ast.NewIdent(funName),
Type: &ast.FuncType{
Params: &ast.FieldList{List: paramTypes},
Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)},
},
Body: extractedBlock,
}
if isMethod {
var names []*ast.Ident
if receiverUsed {
names = append(names, ast.NewIdent(receiverName))
}
newFunc.Recv = &ast.FieldList{
List: []*ast.Field{{
Names: names,
Type: receiver.Type,
}},
}
}
// Create variable declarations for any identifiers that need to be initialized prior to
// calling the extracted function. We do not manually initialize variables if every return
// value is uninitialized. We can use := to initialize the variables in this situation.
@ -624,17 +602,49 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte
return nil, nil, err
}
}
// Build the extracted function. We format the function declaration and body
// separately, so that comments are printed relative to the extracted
// BlockStmt.
//
// In other words, extractedBlock and extractedComments were parsed from a
// synthetic function declaration of the form func _() { ... }. If we now
// print the real function declaration, the length of the signature will have
// grown, causing some comment positions to be computed as inside the
// signature itself.
newFunc := &ast.FuncDecl{
Name: ast.NewIdent(funName),
Type: &ast.FuncType{
Params: &ast.FieldList{List: paramTypes},
Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)},
},
// Body handled separately -- see above.
}
if isMethod {
var names []*ast.Ident
if receiverUsed {
names = append(names, ast.NewIdent(receiverName))
}
newFunc.Recv = &ast.FieldList{
List: []*ast.Field{{
Names: names,
Type: receiver.Type,
}},
}
}
if err := format.Node(&newFuncBuf, fset, newFunc); err != nil {
return nil, nil, err
}
// Find all the comments within the range and print them to be put somewhere.
// TODO(suzmue): print these in the extracted function at the correct place.
for _, cg := range file.Comments {
if cg.Pos().IsValid() && cg.Pos() < end && cg.Pos() >= start {
for _, c := range cg.List {
fmt.Fprintln(&commentBuf, c.Text)
// Write a space between the end of the function signature and opening '{'.
if err := newFuncBuf.WriteByte(' '); err != nil {
return nil, nil, err
}
commentedNode := &printer.CommentedNode{
Node: extractedBlock,
Comments: extractedComments,
}
if err := format.Node(&newFuncBuf, fset, commentedNode); err != nil {
return nil, nil, err
}
// We're going to replace the whole enclosing function,
@ -1187,25 +1197,25 @@ func varOverridden(info *types.Info, firstUse *ast.Ident, obj types.Object, isFr
return isOverriden
}
// parseBlockStmt generates an AST file from the given text. We then return the portion of the
// file that represents the text.
func parseBlockStmt(fset *token.FileSet, src []byte) (*ast.BlockStmt, error) {
// parseStmts parses the specified source (a list of statements) and
// returns them as a BlockStmt along with any associated comments.
func parseStmts(fset *token.FileSet, src []byte) (*ast.BlockStmt, []*ast.CommentGroup, error) {
text := "package main\nfunc _() { " + string(src) + " }"
extract, err := parser.ParseFile(fset, "", text, parser.SkipObjectResolution)
file, err := parser.ParseFile(fset, "", text, parser.ParseComments|parser.SkipObjectResolution)
if err != nil {
return nil, err
return nil, nil, err
}
if len(extract.Decls) == 0 {
return nil, fmt.Errorf("parsed file does not contain any declarations")
if len(file.Decls) != 1 {
return nil, nil, fmt.Errorf("got %d declarations, want 1", len(file.Decls))
}
decl, ok := extract.Decls[0].(*ast.FuncDecl)
decl, ok := file.Decls[0].(*ast.FuncDecl)
if !ok {
return nil, fmt.Errorf("parsed file does not contain expected function declaration")
return nil, nil, bug.Errorf("parsed file does not contain expected function declaration")
}
if decl.Body == nil {
return nil, fmt.Errorf("extracted function has no body")
return nil, nil, bug.Errorf("extracted function has no body")
}
return decl.Body, nil
return decl.Body, file.Comments, nil
}
// generateReturnInfo generates the information we need to adjust the return statements and

Просмотреть файл

@ -237,20 +237,14 @@ func (b *B) LongListWithT(ctx context.Context, t *testing.T) (int, error) {
+}
+
-- @contextFuncB/context.go --
@@ -33 +33,6 @@
- sum := b.x + b.y //@loc(B_AddPWithB, re`(?s:^.*?Err\(\))`)
+ //@loc(B_AddPWithB, re`(?s:^.*?Err\(\))`)
@@ -33 +33,4 @@
+ return newFunction(ctx, tB, b)
+}
+
+func newFunction(ctx context.Context, tB *testing.B, b *B) (int, error) {
+ sum := b.x + b.y
-- @contextFuncT/context.go --
@@ -42 +42,6 @@
- p4 := p1 + p2 //@loc(B_LongListWithT, re`(?s:^.*?Err\(\))`)
+ //@loc(B_LongListWithT, re`(?s:^.*?Err\(\))`)
@@ -42 +42,4 @@
+ return newFunction(ctx, t, p1, p2, p3)
+}
+
+func newFunction(ctx context.Context, t *testing.T, p1 int, p2 int, p3 int) (int, error) {
+ p4 := p1 + p2

Просмотреть файл

@ -17,12 +17,11 @@ func _() { //@codeaction("{", "refactor.extract.function", end=closeBracket, res
package extract
func _() { //@codeaction("{", "refactor.extract.function", end=closeBracket, result=outer)
//@codeaction("a", "refactor.extract.function", end=end, result=inner)
newFunction() //@loc(end, "4")
}
func newFunction() {
a := 1
a := 1 //@codeaction("a", "refactor.extract.function", end=end, result=inner)
_ = a + 4
} //@loc(closeBracket, "}")
@ -30,12 +29,11 @@ func newFunction() {
package extract
func _() { //@codeaction("{", "refactor.extract.function", end=closeBracket, result=outer)
//@codeaction("a", "refactor.extract.function", end=end, result=inner)
newFunction() //@loc(end, "4")
}
func newFunction() {
a := 1
a := 1 //@codeaction("a", "refactor.extract.function", end=end, result=inner)
_ = a + 4
} //@loc(closeBracket, "}")
@ -55,7 +53,6 @@ package extract
func _() bool {
x := 1
//@codeaction("if", "refactor.extract.function", end=ifend, result=return)
shouldReturn, b := newFunction(x)
if shouldReturn {
return b
@ -64,7 +61,7 @@ func _() bool {
}
func newFunction(x int) (bool, bool) {
if x == 0 {
if x == 0 { //@codeaction("if", "refactor.extract.function", end=ifend, result=return)
return true, true
}
return false, false
@ -85,12 +82,11 @@ func _() bool {
package extract
func _() bool {
//@codeaction("x", "refactor.extract.function", end=rnnEnd, result=rnn)
return newFunction() //@loc(rnnEnd, "false")
}
func newFunction() bool {
x := 1
x := 1 //@codeaction("x", "refactor.extract.function", end=rnnEnd, result=rnn)
if x == 0 {
return true
}
@ -123,7 +119,6 @@ import "fmt"
func _() (int, string, error) {
x := 1
y := "hello"
//@codeaction("z", "refactor.extract.function", end=rcEnd, result=rc)
z, shouldReturn, i, s, err := newFunction(y, x)
if shouldReturn {
return i, s, err
@ -132,7 +127,7 @@ func _() (int, string, error) {
}
func newFunction(y string, x int) (string, bool, int, string, error) {
z := "bye"
z := "bye" //@codeaction("z", "refactor.extract.function", end=rcEnd, result=rc)
if y == z {
return "", true, x, y, fmt.Errorf("same")
} else if false {
@ -168,12 +163,11 @@ import "fmt"
func _() (int, string, error) {
x := 1
y := "hello"
//@codeaction("z", "refactor.extract.function", end=rcnnEnd, result=rcnn)
return newFunction(y, x) //@loc(rcnnEnd, "nil")
}
func newFunction(y string, x int) (int, string, error) {
z := "bye"
z := "bye" //@codeaction("z", "refactor.extract.function", end=rcnnEnd, result=rcnn)
if y == z {
return x, y, fmt.Errorf("same")
} else if false {
@ -204,7 +198,6 @@ import "go/ast"
func _() {
ast.Inspect(ast.NewIdent("a"), func(n ast.Node) bool {
//@codeaction("if", "refactor.extract.function", end=rflEnd, result=rfl)
shouldReturn, b := newFunction(n)
if shouldReturn {
return b
@ -214,7 +207,7 @@ func _() {
}
func newFunction(n ast.Node) (bool, bool) {
if n == nil {
if n == nil { //@codeaction("if", "refactor.extract.function", end=rflEnd, result=rfl)
return true, true
}
return false, false
@ -241,13 +234,12 @@ import "go/ast"
func _() {
ast.Inspect(ast.NewIdent("a"), func(n ast.Node) bool {
//@codeaction("if", "refactor.extract.function", end=rflnnEnd, result=rflnn)
return newFunction(n) //@loc(rflnnEnd, "false")
})
}
func newFunction(n ast.Node) bool {
if n == nil {
if n == nil { //@codeaction("if", "refactor.extract.function", end=rflnnEnd, result=rflnn)
return true
}
return false
@ -271,7 +263,6 @@ package extract
func _() string {
x := 1
//@codeaction("if", "refactor.extract.function", end=riEnd, result=ri)
shouldReturn, s := newFunction(x)
if shouldReturn {
return s
@ -281,7 +272,7 @@ func _() string {
}
func newFunction(x int) (bool, string) {
if x == 0 {
if x == 0 { //@codeaction("if", "refactor.extract.function", end=riEnd, result=ri)
x = 3
return true, "a"
}
@ -306,12 +297,11 @@ package extract
func _() string {
x := 1
//@codeaction("if", "refactor.extract.function", end=rinnEnd, result=rinn)
return newFunction(x) //@loc(rinnEnd, "\"b\"")
}
func newFunction(x int) string {
if x == 0 {
if x == 0 { //@codeaction("if", "refactor.extract.function", end=rinnEnd, result=rinn)
x = 3
return "a"
}
@ -336,7 +326,6 @@ package extract
func _() {
a := 1
//@codeaction("a", "refactor.extract.function", end=araend, result=ara)
a = newFunction(a) //@loc(araend, "2")
b := a * 2 //@codeaction("b", "refactor.extract.function", end=arbend, result=arb)
@ -344,7 +333,7 @@ func _() {
}
func newFunction(a int) int {
a = 5
a = 5 //@codeaction("a", "refactor.extract.function", end=araend, result=ara)
a = a + 2
return a
}
@ -357,12 +346,11 @@ func _() {
a = 5 //@codeaction("a", "refactor.extract.function", end=araend, result=ara)
a = a + 2 //@loc(araend, "2")
//@codeaction("b", "refactor.extract.function", end=arbend, result=arb)
newFunction(a) //@loc(arbend, "4")
}
func newFunction(a int) {
b := a * 2
b := a * 2 //@codeaction("b", "refactor.extract.function", end=arbend, result=arb)
_ = b + 4
}
@ -412,13 +400,12 @@ package extract
func _() {
var a []int
//@codeaction("a", "refactor.extract.function", end=siEnd, result=si)
a, b := newFunction(a) //@loc(siEnd, "4")
a = append(a, b)
}
func newFunction(a []int) ([]int, int) {
a = append(a, 2)
a = append(a, 2) //@codeaction("a", "refactor.extract.function", end=siEnd, result=si)
b := 4
return a, b
}
@ -441,13 +428,12 @@ package extract
func _() {
var b []int
var a int
//@codeaction("a", "refactor.extract.function", end=srEnd, result=sr)
b = newFunction(a, b) //@loc(srEnd, ")")
b[0] = 1
}
func newFunction(a int, b []int) []int {
a = 2
a = 2 //@codeaction("a", "refactor.extract.function", end=srEnd, result=sr)
b = []int{}
b = append(b, a)
return b
@ -472,7 +458,6 @@ package extract
func _() {
var b []int
//@codeaction("a", "refactor.extract.function", end=upEnd, result=up)
a, b := newFunction(b) //@loc(upEnd, ")")
b[0] = 1
if a == 2 {
@ -481,7 +466,7 @@ func _() {
}
func newFunction(b []int) (int, []int) {
a := 2
a := 2 //@codeaction("a", "refactor.extract.function", end=upEnd, result=up)
b = []int{}
b = append(b, a)
return a, b
@ -503,9 +488,6 @@ func _() {
package extract
func _() {
/* comment in the middle of a line */
//@codeaction("a", "refactor.extract.function", end=commentEnd, result=comment1)
// Comment on its own line //@codeaction("Comment", "refactor.extract.function", end=commentEnd, result=comment2)
newFunction() //@loc(commentEnd, "4"),codeaction("_", "refactor.extract.function", end=lastComment, result=comment3)
// Comment right after 3 + 4
@ -513,8 +495,8 @@ func _() {
}
func newFunction() {
a := 1
a := /* comment in the middle of a line */ 1 //@codeaction("a", "refactor.extract.function", end=commentEnd, result=comment1)
// Comment on its own line //@codeaction("Comment", "refactor.extract.function", end=commentEnd, result=comment2)
_ = a + 4
}
@ -602,12 +584,11 @@ import "slices"
// issue go#64821
func _() {
//@codeaction("var", "refactor.extract.function", end=anonEnd, result=anon1)
newFunction() //@loc(anonEnd, ")")
}
func newFunction() {
var s []string
var s []string //@codeaction("var", "refactor.extract.function", end=anonEnd, result=anon1)
slices.SortFunc(s, func(a, b string) int {
return cmp.Compare(a, b)
})

Просмотреть файл

@ -26,13 +26,12 @@ package extract
import "fmt"
func main() {
//@codeaction("x", "refactor.extract.function", end=end, result=ext)
x := newFunction() //@loc(end, "}")
fmt.Printf("%x\n", x)
}
func newFunction() []rune {
x := []rune{}
x := []rune{} //@codeaction("x", "refactor.extract.function", end=end, result=ext)
s := "HELLO"
for _, c := range s {
x = append(x, c)

Просмотреть файл

@ -0,0 +1,35 @@
This test checks that function extraction moves comments along with the
extracted code.
-- main.go --
package main
type F struct{}
func (f *F) func1() {
println("a")
println("b") //@ codeaction("print", "refactor.extract.function", end=end, result=result)
// This line prints the third letter of the alphabet.
println("c") //@loc(end, ")")
println("d")
}
-- @result/main.go --
package main
type F struct{}
func (f *F) func1() {
println("a")
newFunction() //@loc(end, ")")
println("d")
}
func newFunction() {
println("b") //@ codeaction("print", "refactor.extract.function", end=end, result=result)
// This line prints the third letter of the alphabet.
println("c")
}

Просмотреть файл

@ -29,7 +29,6 @@ import (
)
func F() error {
//@codeaction("a", "refactor.extract.function", end=endF, result=F)
a, b, shouldReturn, err := newFunction()
if shouldReturn {
return err
@ -39,7 +38,7 @@ func F() error {
}
func newFunction() ([]byte, []byte, bool, error) {
a, err := json.Marshal(0)
a, err := json.Marshal(0) //@codeaction("a", "refactor.extract.function", end=endF, result=F)
if err != nil {
return nil, nil, true, fmt.Errorf("1: %w", err)
}
@ -78,7 +77,6 @@ import (
)
func G() (x, y int) {
//@codeaction("v", "refactor.extract.function", end=endG, result=G)
v, shouldReturn, x1, y1 := newFunction()
if shouldReturn {
return x1, y1
@ -88,7 +86,7 @@ func G() (x, y int) {
}
func newFunction() (int, bool, int, int) {
v := rand.Int()
v := rand.Int() //@codeaction("v", "refactor.extract.function", end=endG, result=G)
if v < 0 {
return 0, true, 1, 2
}