diff --git a/go/ast/astutil/enclosing.go b/go/ast/astutil/enclosing.go index 6b7052b89..a5c6d6d4f 100644 --- a/go/ast/astutil/enclosing.go +++ b/go/ast/astutil/enclosing.go @@ -11,6 +11,8 @@ import ( "go/ast" "go/token" "sort" + + "golang.org/x/tools/internal/typeparams" ) // PathEnclosingInterval returns the node that encloses the source @@ -294,8 +296,8 @@ func childrenOf(n ast.Node) []ast.Node { case *ast.FieldList: children = append(children, - tok(n.Opening, len("(")), - tok(n.Closing, len(")"))) + tok(n.Opening, len("(")), // or len("[") + tok(n.Closing, len(")"))) // or len("]") case *ast.File: // TODO test: Doc @@ -322,6 +324,9 @@ func childrenOf(n ast.Node) []ast.Node { children = append(children, n.Recv) } children = append(children, n.Name) + if tparams := typeparams.ForFuncType(n.Type); tparams != nil { + children = append(children, tparams) + } if n.Type.Params != nil { children = append(children, n.Type.Params) } @@ -371,8 +376,13 @@ func childrenOf(n ast.Node) []ast.Node { case *ast.IndexExpr: children = append(children, - tok(n.Lbrack, len("{")), - tok(n.Rbrack, len("}"))) + tok(n.Lbrack, len("[")), + tok(n.Rbrack, len("]"))) + + case *typeparams.IndexListExpr: + children = append(children, + tok(n.Lbrack, len("[")), + tok(n.Rbrack, len("]"))) case *ast.InterfaceType: children = append(children, @@ -581,6 +591,8 @@ func NodeDescription(n ast.Node) string { return "decrement statement" case *ast.IndexExpr: return "index expression" + case *typeparams.IndexListExpr: + return "index list expression" case *ast.InterfaceType: return "interface type" case *ast.KeyValueExpr: diff --git a/go/ast/astutil/enclosing_test.go b/go/ast/astutil/enclosing_test.go index 107f87c55..5e86ff93c 100644 --- a/go/ast/astutil/enclosing_test.go +++ b/go/ast/astutil/enclosing_test.go @@ -19,6 +19,7 @@ import ( "testing" "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/internal/typeparams" ) // pathToString returns a string containing the concrete types of the @@ -59,7 +60,10 @@ func findInterval(t *testing.T, fset *token.FileSet, input, substr string) (f *a } // Common input for following tests. -const input = ` +var input = makeInput() + +func makeInput() string { + src := ` // Hello. package main import "fmt" @@ -70,52 +74,88 @@ func main() { } ` + if typeparams.Enabled { + src += ` +func g[A any, P interface{ctype1| ~ctype2}](a1 A, p1 P) {} + +type PT[T constraint] struct{ t T } + +var v GT[targ1] + +var h = g[ targ2, targ3] +` + } + return src +} + func TestPathEnclosingInterval_Exact(t *testing.T) { - // For the exact tests, we check that a substring is mapped to - // the canonical string for the node it denotes. - tests := []struct { + type testCase struct { substr string // first occurrence of this string indicates interval node string // complete text of expected containing node - }{ + } + + dup := func(s string) testCase { return testCase{s, s} } + // For the exact tests, we check that a substring is mapped to + // the canonical string for the node it denotes. + tests := []testCase{ {"package", input[11 : len(input)-1]}, {"\npack", input[11 : len(input)-1]}, - {"main", - "main"}, + dup("main"), {"import", "import \"fmt\""}, - {"\"fmt\"", - "\"fmt\""}, + dup("\"fmt\""), {"\nfunc f() {}\n", "func f() {}"}, {"x ", "x"}, {" y", "y"}, - {"z", - "z"}, + dup("z"), {" + ", "x + y"}, {" :=", "z := (x + y)"}, - {"x + y", - "x + y"}, - {"(x + y)", - "(x + y)"}, + dup("x + y"), + dup("(x + y)"), {" (x + y) ", "(x + y)"}, {" (x + y) // add", "(x + y)"}, {"func", "func f() {}"}, - {"func f() {}", - "func f() {}"}, + dup("func f() {}"), {"\nfun", "func f() {}"}, {" f", "f"}, } + if typeparams.Enabled { + tests = append(tests, []testCase{ + dup("[A any, P interface{ctype1| ~ctype2}]"), + {"[", "[A any, P interface{ctype1| ~ctype2}]"}, + dup("A"), + {" any", "any"}, + dup("ctype1"), + {"|", "ctype1| ~ctype2"}, + dup("ctype2"), + {"~", "~ctype2"}, + dup("~ctype2"), + {" ~ctype2", "~ctype2"}, + {"]", "[A any, P interface{ctype1| ~ctype2}]"}, + dup("a1"), + dup("a1 A"), + dup("(a1 A, p1 P)"), + dup("type PT[T constraint] struct{ t T }"), + dup("PT"), + dup("[T constraint]"), + dup("constraint"), + dup("targ1"), + {" targ2", "targ2"}, + dup("g[ targ2, targ3]"), + }...) + } for _, test := range tests { f, start, end := findInterval(t, new(token.FileSet), input, test.substr) if f == nil { @@ -145,13 +185,14 @@ func TestPathEnclosingInterval_Exact(t *testing.T) { } func TestPathEnclosingInterval_Paths(t *testing.T) { + type testCase struct { + substr string // first occurrence of this string indicates interval + path string // the pathToString(),exact of the expected path + } // For these tests, we check only the path of the enclosing // node, but not its complete text because it's often quite // large when !exact. - tests := []struct { - substr string // first occurrence of this string indicates interval - path string // the pathToString(),exact of the expected path - }{ + tests := []testCase{ {"// add", "[BlockStmt FuncDecl File],false"}, {"(x + y", @@ -179,6 +220,18 @@ func TestPathEnclosingInterval_Paths(t *testing.T) { {"f() // NB", "[CallExpr ExprStmt BlockStmt FuncDecl File],true"}, } + if typeparams.Enabled { + tests = append(tests, []testCase{ + {" any", "[Ident Field FieldList FuncDecl File],true"}, + {"|", "[BinaryExpr Field FieldList InterfaceType Field FieldList FuncDecl File],true"}, + {"ctype2", + "[Ident UnaryExpr BinaryExpr Field FieldList InterfaceType Field FieldList FuncDecl File],true"}, + {"a1", "[Ident Field FieldList FuncDecl File],true"}, + {"PT[T constraint]", "[TypeSpec GenDecl File],false"}, + {"[T constraint]", "[FieldList TypeSpec GenDecl File],true"}, + {"targ2", "[Ident IndexListExpr ValueSpec GenDecl File],true"}, + }...) + } for _, test := range tests { f, start, end := findInterval(t, new(token.FileSet), input, test.substr) if f == nil {