internal/refactor/inline: remove eta abstraction inlining assignments

Remove the unnecessary eta abstraction reported in golang/go#65217, by
introducing a new strategy for rewriting assignments.

This strategy involves analyzing both LHS and RHS of an assignment, and
choosing between three substrategies:
 - spread the result expressions in cases where types are unambiguous
 - predeclare LHS variables in cases where the return is itself a spread
   call
 - convert RHS expressions if types involve implicit conversions

Doing this involved some fixes to the logic for detecting trivial
conversions, and tracking some additional information about untyped nils
in return expressions.

Since this strategy avoids literalization by modifying assignments in
place, it must be able to avoid nested blocks, and so it explicitly
records that braces may be elided.

There is more work to be done here, both improving the writeType helper,
and by variable declarations, but this CL lays a foundation for later
expansion.

For golang/go#65217

Change-Id: I9b3b595f7f678ab9b86ef7cf19936fd818b45426
Reviewed-on: https://go-review.googlesource.com/c/tools/+/580835
Reviewed-by: Alan Donovan <adonovan@google.com>
Auto-Submit: Robert Findley <rfindley@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
Rob Findley 2024-04-12 18:37:50 +00:00 коммит произвёл Gopher Robot
Родитель fcea13b66c
Коммит bcec0994e0
6 изменённых файлов: 780 добавлений и 124 удалений

Просмотреть файл

@ -0,0 +1,58 @@
This test reproduces condition of golang/go#65217, where the inliner created an
unnecessary eta abstraction.
-- go.mod --
module unused.mod
go 1.18
-- a/a.go --
package a
type S struct{}
func (S) Int() int { return 0 }
func _() {
var s S
_ = f(s, s.Int())
var j int
j = f(s, s.Int())
_ = j
}
func _() {
var s S
i := f(s, s.Int())
_ = i
}
func f(unused S, i int) int { //@codeaction("unused", "unused", "refactor.rewrite", rewrite, "Refactor: remove unused parameter"), diag("unused", re`unused`)
return i
}
-- @rewrite/a/a.go --
package a
type S struct{}
func (S) Int() int { return 0 }
func _() {
var s S
_ = f(s.Int())
var j int
j = f(s.Int())
_ = j
}
func _() {
var s S
var _ S = s
i := f(s.Int())
_ = i
}
func f(i int) int { //@codeaction("unused", "unused", "refactor.rewrite", rewrite, "Refactor: remove unused parameter"), diag("unused", re`unused`)
return i
}

Просмотреть файл

@ -32,25 +32,33 @@ type gobCallee struct {
Content []byte // file content, compacted to a single func decl
// results of type analysis (does not reach go/types data structures)
PkgPath string // package path of declaring package
Name string // user-friendly name for error messages
Unexported []string // names of free objects that are unexported
FreeRefs []freeRef // locations of references to free objects
FreeObjs []object // descriptions of free objects
ValidForCallStmt bool // function body is "return expr" where expr is f() or <-ch
NumResults int // number of results (according to type, not ast.FieldList)
Params []*paramInfo // information about parameters (incl. receiver)
Results []*paramInfo // information about result variables
Effects []int // order in which parameters are evaluated (see calleefx)
HasDefer bool // uses defer
HasBareReturn bool // uses bare return in non-void function
TotalReturns int // number of return statements
TrivialReturns int // number of return statements with trivial result conversions
Labels []string // names of all control labels
Falcon falconResult // falcon constraint system
PkgPath string // package path of declaring package
Name string // user-friendly name for error messages
Unexported []string // names of free objects that are unexported
FreeRefs []freeRef // locations of references to free objects
FreeObjs []object // descriptions of free objects
ValidForCallStmt bool // function body is "return expr" where expr is f() or <-ch
NumResults int // number of results (according to type, not ast.FieldList)
Params []*paramInfo // information about parameters (incl. receiver)
Results []*paramInfo // information about result variables
Effects []int // order in which parameters are evaluated (see calleefx)
HasDefer bool // uses defer
HasBareReturn bool // uses bare return in non-void function
Returns [][]returnOperandFlags // metadata about result expressions for each return
Labels []string // names of all control labels
Falcon falconResult // falcon constraint system
}
// A freeRef records a reference to a free object. Gob-serializable.
// returnOperandFlags records metadata about a single result expression in a return
// statement.
type returnOperandFlags int
const (
nonTrivialResult returnOperandFlags = 1 << iota // return operand has non-trivial conversion to result type
untypedNilResult // return operand is nil literal
)
// A freeRef records a reference to a free object. Gob-serializable.
// (This means free relative to the FuncDecl as a whole, i.e. excluding parameters.)
type freeRef struct {
Offset int // byte offset of the reference relative to the FuncDecl
@ -264,11 +272,10 @@ func AnalyzeCallee(logf func(string, ...any), fset *token.FileSet, pkg *types.Pa
// Record information about control flow in the callee
// (but not any nested functions).
var (
hasDefer = false
hasBareReturn = false
totalReturns = 0
trivialReturns = 0
labels []string
hasDefer = false
hasBareReturn = false
returnInfo [][]returnOperandFlags
labels []string
)
ast.Inspect(decl.Body, func(n ast.Node) bool {
switch n := n.(type) {
@ -279,34 +286,37 @@ func AnalyzeCallee(logf func(string, ...any), fset *token.FileSet, pkg *types.Pa
case *ast.LabeledStmt:
labels = append(labels, n.Label.Name)
case *ast.ReturnStmt:
totalReturns++
// Are implicit assignment conversions
// to result variables all trivial?
trivial := true
var resultInfo []returnOperandFlags
if len(n.Results) > 0 {
argType := func(i int) types.Type {
return info.TypeOf(n.Results[i])
argInfo := func(i int) (ast.Expr, types.Type) {
expr := n.Results[i]
return expr, info.TypeOf(expr)
}
if len(n.Results) == 1 && sig.Results().Len() > 1 {
// Spread return: return f() where f.Results > 1.
tuple := info.TypeOf(n.Results[0]).(*types.Tuple)
argType = func(i int) types.Type {
return tuple.At(i).Type()
argInfo = func(i int) (ast.Expr, types.Type) {
return nil, tuple.At(i).Type()
}
}
for i := 0; i < sig.Results().Len(); i++ {
if !trivialConversion(argType(i), sig.Results().At(i)) {
trivial = false
break
expr, typ := argInfo(i)
var flags returnOperandFlags
if typ == types.Typ[types.UntypedNil] { // untyped nil is preserved by go/types
flags |= untypedNilResult
}
if !trivialConversion(info.Types[expr].Value, typ, sig.Results().At(i).Type()) {
flags |= nonTrivialResult
}
resultInfo = append(resultInfo, flags)
}
} else if sig.Results().Len() > 0 {
hasBareReturn = true
}
if trivial {
trivialReturns++
}
returnInfo = append(returnInfo, resultInfo)
}
return true
})
@ -353,8 +363,7 @@ func AnalyzeCallee(logf func(string, ...any), fset *token.FileSet, pkg *types.Pa
Effects: effects,
HasDefer: hasDefer,
HasBareReturn: hasBareReturn,
TotalReturns: totalReturns,
TrivialReturns: trivialReturns,
Returns: returnInfo,
Labels: labels,
Falcon: falcon,
}}, nil

Просмотреть файл

@ -317,7 +317,7 @@ func TestFalconComplex(t *testing.T) {
"Complex arithmetic (good).",
`func f(re, im float64, z complex128) byte { return "x"[int(real(complex(re, im)*complex(re, -im)-z))] }`,
`func _() { f(1, 2, 5+0i) }`,
`func _() { _ = "x"[int(real(complex(float64(1), float64(2))*complex(float64(1), -float64(2))-(5+0i)))] }`,
`func _() { _ = "x"[int(real(complex(1, 2)*complex(1, -2)-(5+0i)))] }`,
},
{
"Complex arithmetic (bad).",

Просмотреть файл

@ -40,6 +40,8 @@ type Caller struct {
enclosingFunc *ast.FuncDecl // top-level function/method enclosing the call, if any
}
type unit struct{} // for representing sets as maps
// Inline inlines the called function (callee) into the function call (caller)
// and returns the updated, formatted content of the caller source file.
//
@ -109,51 +111,56 @@ func Inline(logf func(string, ...any), caller *Caller, callee *Callee) ([]byte,
// This elision is only safe when the ExprStmt is beneath a
// BlockStmt, CaseClause.Body, or CommClause.Body;
// (see "statement theory").
elideBraces := false
if newBlock, ok := res.new.(*ast.BlockStmt); ok {
i := nodeIndex(caller.path, res.old)
parent := caller.path[i+1]
var body []ast.Stmt
switch parent := parent.(type) {
case *ast.BlockStmt:
body = parent.List
case *ast.CommClause:
body = parent.Body
case *ast.CaseClause:
body = parent.Body
}
if body != nil {
callerNames := declares(body)
//
// The inlining analysis may have already determined that eliding braces is
// safe. Otherwise, we analyze its safety here.
elideBraces := res.elideBraces
if !elideBraces {
if newBlock, ok := res.new.(*ast.BlockStmt); ok {
i := nodeIndex(caller.path, res.old)
parent := caller.path[i+1]
var body []ast.Stmt
switch parent := parent.(type) {
case *ast.BlockStmt:
body = parent.List
case *ast.CommClause:
body = parent.Body
case *ast.CaseClause:
body = parent.Body
}
if body != nil {
callerNames := declares(body)
// If BlockStmt is a function body,
// include its receiver, params, and results.
addFieldNames := func(fields *ast.FieldList) {
if fields != nil {
for _, field := range fields.List {
for _, id := range field.Names {
callerNames[id.Name] = true
// If BlockStmt is a function body,
// include its receiver, params, and results.
addFieldNames := func(fields *ast.FieldList) {
if fields != nil {
for _, field := range fields.List {
for _, id := range field.Names {
callerNames[id.Name] = true
}
}
}
}
}
switch f := caller.path[i+2].(type) {
case *ast.FuncDecl:
addFieldNames(f.Recv)
addFieldNames(f.Type.Params)
addFieldNames(f.Type.Results)
case *ast.FuncLit:
addFieldNames(f.Type.Params)
addFieldNames(f.Type.Results)
}
switch f := caller.path[i+2].(type) {
case *ast.FuncDecl:
addFieldNames(f.Recv)
addFieldNames(f.Type.Params)
addFieldNames(f.Type.Results)
case *ast.FuncLit:
addFieldNames(f.Type.Params)
addFieldNames(f.Type.Results)
}
if len(callerLabels(caller.path)) > 0 {
// TODO(adonovan): be more precise and reject
// only forward gotos across the inlined block.
logf("keeping block braces: caller uses control labels")
} else if intersects(declares(newBlock.List), callerNames) {
logf("keeping block braces: avoids name conflict")
} else {
elideBraces = true
if len(callerLabels(caller.path)) > 0 {
// TODO(adonovan): be more precise and reject
// only forward gotos across the inlined block.
logf("keeping block braces: caller uses control labels")
} else if intersects(declares(newBlock.List), callerNames) {
logf("keeping block braces: avoids name conflict")
} else {
elideBraces = true
}
}
}
}
@ -307,9 +314,23 @@ type newImport struct {
type result struct {
newImports []newImport
old, new ast.Node // e.g. replace call expr by callee function body expression
// If elideBraces is set, old is an ast.Stmt and new is an ast.BlockStmt to
// be spliced in. This allows the inlining analysis to assert that inlining
// the block is OK; if elideBraces is unset and old is an ast.Stmt and new is
// an ast.BlockStmt, braces may still be elided if the post-processing
// analysis determines that it is safe to do so.
//
// Ideally, it would not be necessary for the inlining analysis to "reach
// through" to the post-processing pass in this way. Instead, inlining could
// just set old to be an ast.BlockStmt and rewrite the entire BlockStmt, but
// unfortunately in order to preserve comments, it is important that inlining
// replace as little syntax as possible.
elideBraces bool
old, new ast.Node // e.g. replace call expr by callee function body expression
}
type logger = func(string, ...any)
// inline returns a pair of an old node (the call, or something
// enclosing it) and a new node (its replacement, which may be a
// combination of caller, callee, and new nodes), along with the set
@ -329,7 +350,7 @@ type result struct {
// candidate for evaluating an alternative fully self-contained tree
// representation, such as any proposed solution to #20744, or even
// dst or some private fork of go/ast.)
func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*result, error) {
func inline(logf logger, caller *Caller, callee *gobCallee) (*result, error) {
checkInfoFields(caller.Info)
// Inlining of dynamic calls is not currently supported,
@ -675,7 +696,7 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu
updateCalleeParams(calleeDecl, params)
// Create a var (param = arg; ...) decl for use by some strategies.
bindingDeclStmt := createBindingDecl(logf, caller, args, calleeDecl, callee.Results)
bindingDecl := createBindingDecl(logf, caller, args, calleeDecl, callee.Results)
var remainingArgs []ast.Expr
for _, arg := range args {
@ -797,11 +818,11 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu
len(calleeDecl.Body.List[0].(*ast.ReturnStmt).Results) > 0 { // not a bare return
results := calleeDecl.Body.List[0].(*ast.ReturnStmt).Results
context := callContext(caller.path)
parent, grandparent := callContext(caller.path)
// statement context
if stmt, ok := context.(*ast.ExprStmt); ok &&
(!needBindingDecl || bindingDeclStmt != nil) {
if stmt, ok := parent.(*ast.ExprStmt); ok &&
(!needBindingDecl || bindingDecl != nil) {
logf("strategy: reduce stmt-context call to { return exprs }")
clearPositions(calleeDecl.Body)
@ -817,7 +838,7 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu
res.old = stmt
res.new = &ast.BlockStmt{
List: []ast.Stmt{
bindingDeclStmt,
bindingDecl.stmt,
&ast.ExprStmt{X: results[0]},
},
}
@ -841,7 +862,7 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu
// Reduces to: { var (bindings); _, _ = exprs }
res.new = &ast.BlockStmt{
List: []ast.Stmt{
bindingDeclStmt,
bindingDecl.stmt,
discard,
},
}
@ -850,16 +871,48 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu
return res, nil
}
// Assignment context.
//
// If there is no binding decl, or if the binding decl declares no names,
// an assignment a, b := f() can be reduced to a, b := x, y.
if stmt, ok := parent.(*ast.AssignStmt); ok &&
is[*ast.BlockStmt](grandparent) &&
(!needBindingDecl || (bindingDecl != nil && len(bindingDecl.names) == 0)) {
// Reduces to: { var (bindings); lhs... := rhs... }
if newStmts, ok := assignStmts(logf, caller, stmt, callee, results); ok {
logf("strategy: reduce assign-context call to { return exprs }")
clearPositions(calleeDecl.Body)
block := &ast.BlockStmt{
List: newStmts,
}
if needBindingDecl {
block.List = prepend(bindingDecl.stmt, block.List...)
}
// assignStmts does not introduce new bindings, and replacing an
// assignment only works if the replacement occurs in the same scope.
// Therefore, we must ensure that braces are elided.
res.elideBraces = true
res.old = stmt
res.new = block
return res, nil
}
}
// expression context
if !needBindingDecl {
clearPositions(calleeDecl.Body)
anyNonTrivialReturns := hasNonTrivialReturn(callee.Returns)
if callee.NumResults == 1 {
logf("strategy: reduce expr-context call to { return expr }")
// (includes some simple tail-calls)
// Make implicit return conversion explicit.
if callee.TrivialReturns < callee.TotalReturns {
if anyNonTrivialReturns {
results[0] = convert(calleeDecl.Type.Results.List[0].Type, results[0])
}
@ -867,7 +920,7 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu
res.new = results[0]
return res, nil
} else if callee.TrivialReturns == callee.TotalReturns {
} else if !anyNonTrivialReturns {
logf("strategy: reduce spread-context call to { return expr }")
// There is no general way to reify conversions in a spread
// return, hence the requirement above.
@ -885,8 +938,8 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu
// printf(f())
// or spread return statement:
// return f()
res.old = context
switch context := context.(type) {
res.old = parent
switch context := parent.(type) {
case *ast.AssignStmt:
// Inv: the call must be in Rhs[0], not Lhs.
assign := shallowCopy(context)
@ -939,18 +992,19 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu
// TODO(adonovan): add a strategy for a 'void tail
// call', i.e. a call statement prior to an (explicit
// or implicit) return.
if ret, ok := callContext(caller.path).(*ast.ReturnStmt); ok &&
parent, _ := callContext(caller.path)
if ret, ok := parent.(*ast.ReturnStmt); ok &&
len(ret.Results) == 1 &&
tailCallSafeReturn(caller, calleeSymbol, callee) &&
!callee.HasBareReturn &&
(!needBindingDecl || bindingDeclStmt != nil) &&
(!needBindingDecl || bindingDecl != nil) &&
!hasLabelConflict(caller.path, callee.Labels) &&
allResultsUnreferenced {
logf("strategy: reduce tail-call")
body := calleeDecl.Body
clearPositions(body)
if needBindingDecl {
body.List = prepend(bindingDeclStmt, body.List...)
body.List = prepend(bindingDecl.stmt, body.List...)
}
res.old = ret
res.new = body
@ -974,16 +1028,16 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu
// or replaced by a binding decl,
// - caller ExprStmt is in unrestricted statement context.
if stmt := callStmt(caller.path, true); stmt != nil &&
(!needBindingDecl || bindingDeclStmt != nil) &&
(!needBindingDecl || bindingDecl != nil) &&
!callee.HasDefer &&
!hasLabelConflict(caller.path, callee.Labels) &&
callee.TotalReturns == 0 {
len(callee.Returns) == 0 {
logf("strategy: reduce stmt-context call to { stmts }")
body := calleeDecl.Body
var repl ast.Stmt = body
clearPositions(repl)
if needBindingDecl {
body.List = prepend(bindingDeclStmt, body.List...)
body.List = prepend(bindingDecl.stmt, body.List...)
}
res.old = stmt
res.new = repl
@ -1036,10 +1090,10 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu
//
// TODO(adonovan): relax the allResultsUnreferenced requirement
// by adding a parameter-only (no named results) binding decl.
if bindingDeclStmt != nil && allResultsUnreferenced {
if bindingDecl != nil && allResultsUnreferenced {
funcLit.Type.Params.List = nil
remainingArgs = nil
funcLit.Body.List = prepend(bindingDeclStmt, funcLit.Body.List...)
funcLit.Body.List = prepend(bindingDecl.stmt, funcLit.Body.List...)
}
// Emit a new call to a function literal in place of
@ -1339,7 +1393,7 @@ next:
// other arguments are given explicit types in either
// a binding decl or when using the literalization
// strategy.
if len(param.info.Refs) > 0 && !trivialConversion(args[i].typ, params[i].obj) {
if len(param.info.Refs) > 0 && !trivialConversion(args[i].constant, args[i].typ, params[i].obj.Type()) {
arg.expr = convert(params[i].fieldType, arg.expr)
logf("param %q: adding explicit %s -> %s conversion around argument",
param.info.Name, args[i].typ, params[i].obj.Type())
@ -1609,6 +1663,13 @@ func updateCalleeParams(calleeDecl *ast.FuncDecl, params []*parameter) {
calleeDecl.Type.Params.List = newParams
}
// bindingDeclInfo records information about the binding decl produced by
// createBindingDecl.
type bindingDeclInfo struct {
names map[string]bool // names bound by the binding decl; possibly empty
stmt ast.Stmt // the binding decl itself
}
// createBindingDecl constructs a "binding decl" that implements
// parameter assignment and declares any named result variables
// referenced by the callee. It returns nil if there were no
@ -1648,7 +1709,7 @@ func updateCalleeParams(calleeDecl *ast.FuncDecl, params []*parameter) {
//
// Strategies may impose additional checks on return
// conversions, labels, defer, etc.
func createBindingDecl(logf func(string, ...any), caller *Caller, args []*argument, calleeDecl *ast.FuncDecl, results []*paramInfo) ast.Stmt {
func createBindingDecl(logf func(string, ...any), caller *Caller, args []*argument, calleeDecl *ast.FuncDecl, results []*paramInfo) *bindingDeclInfo {
// Spread calls are tricky as they may not align with the
// parameters' field groupings nor types.
// For example, given
@ -1667,8 +1728,8 @@ func createBindingDecl(logf func(string, ...any), caller *Caller, args []*argume
}
var (
specs []ast.Spec
shadowed = make(map[string]bool) // names defined by previous specs
specs []ast.Spec
names = make(map[string]bool) // names defined by previous specs
)
// shadow reports whether any name referenced by spec is
// shadowed by a name declared by a previous spec (since,
@ -1688,14 +1749,14 @@ func createBindingDecl(logf func(string, ...any), caller *Caller, args []*argume
}
freeishNames(free, spec.Type)
for name := range free {
if shadowed[name] {
if names[name] {
logf("binding decl would shadow free name %q", name)
return true
}
}
for _, id := range spec.Names {
if id.Name != "_" {
shadowed[id.Name] = true
names[id.Name] = true
}
}
return false
@ -1770,7 +1831,7 @@ func createBindingDecl(logf func(string, ...any), caller *Caller, args []*argume
},
}
logf("binding decl: %s", debugFormatNode(caller.Fset, stmt))
return stmt
return &bindingDeclInfo{names: names, stmt: stmt}
}
// lookup does a symbol lookup in the lexical environment of the caller.
@ -2170,16 +2231,20 @@ func isPkgLevel(obj types.Object) bool {
return obj.Pkg().Scope().Lookup(obj.Name()) == obj
}
// callContext returns the node immediately enclosing the call
// callContext returns the two nodes immediately enclosing the call
// (specified as a PathEnclosingInterval), ignoring parens.
func callContext(callPath []ast.Node) ast.Node {
func callContext(callPath []ast.Node) (parent, grandparent ast.Node) {
_ = callPath[0].(*ast.CallExpr) // sanity check
for _, n := range callPath[1:] {
if !is[*ast.ParenExpr](n) {
return n
if parent == nil {
parent = n
} else {
return parent, n
}
}
}
return nil
return parent, nil
}
// hasLabelConflict reports whether the set of labels of the function
@ -2243,7 +2308,8 @@ func callerFunc(callPath []ast.Node) ast.Node {
// in a restricted context (such as "if f(); cond {") where it cannot
// be replaced by an arbitrary statement. (See "statement theory".)
func callStmt(callPath []ast.Node, unrestricted bool) *ast.ExprStmt {
stmt, ok := callContext(callPath).(*ast.ExprStmt)
parent, _ := callContext(callPath)
stmt, ok := parent.(*ast.ExprStmt)
if ok && unrestricted {
switch callPath[nodeIndex(callPath, stmt)+1].(type) {
case *ast.LabeledStmt,
@ -2417,11 +2483,16 @@ func clearPositions(root ast.Node) {
fields := v.Type().NumField()
for i := 0; i < fields; i++ {
f := v.Field(i)
// Clearing Pos arbitrarily is destructive,
// as its presence may be semantically significant
// (e.g. CallExpr.Ellipsis, TypeSpec.Assign)
// or affect formatting preferences (e.g. GenDecl.Lparen).
//
// Note: for proper formatting, it may be necessary to be selective
// about which positions we set to 1 vs which we set to token.NoPos.
// (e.g. we can set most to token.NoPos, save the few that are
// significant).
if f.Type() == posType {
// Clearing Pos arbitrarily is destructive,
// as its presence may be semantically significant
// (e.g. CallExpr.Ellipsis, TypeSpec.Assign)
// or affect formatting preferences (e.g. GenDecl.Lparen).
if f.Interface() != token.NoPos {
f.Set(reflect.ValueOf(token.Pos(1)))
}
@ -2653,11 +2724,328 @@ func declares(stmts []ast.Stmt) map[string]bool {
return names
}
// assignStmts rewrites a statement assigning the results of a call into zero
// or more statements that assign its return operands, or (nil, false) if no
// such rewrite is possible. The set of bindings created by the result of
// assignStmts is the same as the set of bindings created by the callerStmt.
//
// The callee must contain exactly one return statement.
//
// This is (once again) a surprisingly complex task. For example, depending on
// types and existing bindings, the assignment
//
// a, b := f()
//
// could be rewritten as:
//
// a, b := 1, 2
//
// but may need to be written as:
//
// a, b := int8(1), int32(2)
//
// In the case where the return statement within f is a spread call to another
// function g(), we cannot explicitly convert the return values inline, and so
// it may be necessary to split the declaration and assignment of variables
// into separate statements:
//
// a, b := g()
//
// or
//
// var a int32
// a, b = g()
//
// or
//
// var (
// a int8
// b int32
// )
// a, b = g()
//
// Note: assignStmts may return (nil, true) if it determines that the rewritten
// assignment consists only of _ = nil assignments.
func assignStmts(logf logger, caller *Caller, callerStmt *ast.AssignStmt, callee *gobCallee, returnOperands []ast.Expr) ([]ast.Stmt, bool) {
assert(len(callee.Returns) == 1, "unexpected multiple returns")
resultInfo := callee.Returns[0]
// When constructing assign statements, we need to make sure that we don't
// modify types on the left-hand side, such as would happen if the type of a
// RHS expression does not match the corresponding LHS type at the caller
// (due to untyped conversion or interface widening).
//
// This turns out to be remarkably tricky to handle correctly.
//
// Substrategies below are labeled as `Substrategy <name>:`.
// Collect LHS information.
var (
lhs []ast.Expr // shallow copy of the LHS slice, for mutation
defs = make([]*ast.Ident, len(callerStmt.Lhs)) // indexes in lhs of defining identifiers
blanks = make([]bool, len(callerStmt.Lhs)) // indexes in lhs of blank identifiers
byType typeutil.Map // map of distinct types -> indexes, for writing specs later
)
for i, expr := range callerStmt.Lhs {
lhs = append(lhs, expr)
if name, ok := expr.(*ast.Ident); ok {
if name.Name == "_" {
blanks[i] = true
continue // no type
}
if obj, isDef := caller.Info.Defs[name]; isDef {
defs[i] = name
typ := obj.Type()
idxs, _ := byType.At(typ).([]int)
idxs = append(idxs, i)
byType.Set(typ, idxs)
}
}
}
// Collect RHS information
//
// The RHS is either a parallel assignment or spread assignment, but by
// looping over both callerStmt.Rhs and returnOperands we handle both.
var (
rhs []ast.Expr // new RHS of assignment, owned by the inliner
callIdx = -1 // index of the call among the original RHS
nilBlankAssigns = make(map[int]unit) // indexes in rhs of _ = nil assignments, which can be deleted
freeNames = make(map[string]bool) // free(ish) names among rhs expressions
nonTrivial = make(map[int]bool) // indexes in rhs of nontrivial result conversions
)
for i, expr := range callerStmt.Rhs {
if expr == caller.Call {
assert(callIdx == -1, "malformed (duplicative) AST")
callIdx = i
for j, returnOperand := range returnOperands {
freeishNames(freeNames, returnOperand)
rhs = append(rhs, returnOperand)
if resultInfo[j]&nonTrivialResult != 0 {
nonTrivial[i+j] = true
}
if blanks[i+j] && resultInfo[j]&untypedNilResult != 0 {
nilBlankAssigns[i+j] = unit{}
}
}
} else {
// We must clone before clearing positions, since e came from the caller.
expr = internalastutil.CloneNode(expr)
clearPositions(expr)
freeishNames(freeNames, expr)
rhs = append(rhs, expr)
}
}
assert(callIdx >= 0, "failed to find call in RHS")
// Substrategy "splice": Check to see if we can simply splice in the result
// expressions from the callee, such as simplifying
//
// x, y := f()
//
// to
//
// x, y := e1, e2
//
// where the types of x and y match the types of e1 and e2.
//
// This works as long as we don't need to write any additional type
// information.
if callerStmt.Tok == token.ASSIGN && // LHS types already determined before call
len(nonTrivial) == 0 { // no non-trivial conversions to worry about
logf("substrategy: slice assignment")
return []ast.Stmt{&ast.AssignStmt{
Lhs: lhs,
Tok: callerStmt.Tok,
TokPos: callerStmt.TokPos,
Rhs: rhs,
}}, true
}
// Inlining techniques below will need to write type information in order to
// preserve the correct types of LHS identifiers.
//
// writeType is a simple helper to write out type expressions.
// TODO(rfindley):
// 1. handle qualified type names (potentially adding new imports)
// 2. expand this to handle more type expressions.
// 3. refactor to share logic with callee rewriting.
universeAny := types.Universe.Lookup("any")
typeExpr := func(typ types.Type, shadows ...map[string]bool) ast.Expr {
var typeName string
switch typ := typ.(type) {
case *types.Basic:
typeName = typ.Name()
case interface{ Obj() *types.TypeName }: // Named, Alias, TypeParam
typeName = typ.Obj().Name()
}
// Special case: check for universe "any".
// TODO(golang/go#66921): this may become unnecessary if any becomes a proper alias.
if typ == universeAny.Type() {
typeName = "any"
}
if typeName == "" {
return nil
}
for _, shadow := range shadows {
if shadow[typeName] {
logf("cannot write shadowed type name %q", typeName)
return nil
}
}
obj, _ := caller.lookup(typeName).(*types.TypeName)
if obj != nil && types.Identical(obj.Type(), typ) {
return ast.NewIdent(typeName)
}
return nil
}
// Substrategy "spread": in the case of a spread call (func f() (T1, T2) return
// g()), since we didn't hit the 'splice' substrategy, there must be some
// non-declaring expression on the LHS. Simplify this by pre-declaring
// variables, rewriting
//
// x, y := f()
//
// to
//
// var x int
// x, y = g()
//
// Which works as long as the predeclared variables do not overlap with free
// names on the RHS.
if len(rhs) != len(lhs) {
assert(len(rhs) == 1 && len(returnOperands) == 1, "expected spread call")
for _, id := range defs {
if id != nil && freeNames[id.Name] {
// By predeclaring variables, we're changing them to be in scope of the
// RHS. We can't do this if their names are free on the RHS.
return nil, false
}
}
// Write out the specs, being careful to avoid shadowing free names in
// their type expressions.
var (
specs []ast.Spec
specIdxs []int
shadow = make(map[string]bool)
)
failed := false
byType.Iterate(func(typ types.Type, v any) {
if failed {
return
}
idxs := v.([]int)
specIdxs = append(specIdxs, idxs[0])
texpr := typeExpr(typ, shadow)
if texpr == nil {
failed = true
return
}
spec := &ast.ValueSpec{
Type: texpr,
}
for _, idx := range idxs {
spec.Names = append(spec.Names, ast.NewIdent(defs[idx].Name))
}
specs = append(specs, spec)
})
if failed {
return nil, false
}
logf("substrategy: spread assignment")
return []ast.Stmt{
&ast.DeclStmt{
Decl: &ast.GenDecl{
Tok: token.VAR,
Specs: specs,
},
},
&ast.AssignStmt{
Lhs: callerStmt.Lhs,
Tok: token.ASSIGN,
Rhs: returnOperands,
},
}, true
}
assert(len(lhs) == len(rhs), "mismatching LHS and RHS")
// Substrategy "convert": write out RHS expressions with explicit type conversions
// as necessary, rewriting
//
// x, y := f()
//
// to
//
// x, y := 1, int32(2)
//
// As required to preserve types.
//
// In the special case of _ = nil, which is disallowed by the type checker
// (since nil has no default type), we delete the assignment.
var origIdxs []int // maps back to original indexes after lhs and rhs are pruned
i := 0
for j := range lhs {
if _, ok := nilBlankAssigns[j]; !ok {
lhs[i] = lhs[j]
rhs[i] = rhs[j]
origIdxs = append(origIdxs, j)
i++
}
}
lhs = lhs[:i]
rhs = rhs[:i]
if len(lhs) == 0 {
logf("trivial assignment after pruning nil blanks assigns")
// After pruning, we have no remaining assignments.
// Signal this by returning a non-nil slice of statements.
return nil, true
}
// Write out explicit conversions as necessary.
//
// A conversion is necessary if the LHS is being defined, and the RHS return
// involved a nontrivial implicit conversion.
for i, expr := range rhs {
idx := origIdxs[i]
if nonTrivial[idx] && defs[idx] != nil {
typ := caller.Info.TypeOf(lhs[i])
texpr := typeExpr(typ)
if texpr == nil {
return nil, false
}
if _, ok := texpr.(*ast.StarExpr); ok {
// TODO(rfindley): is this necessary? Doesn't the formatter add these parens?
texpr = &ast.ParenExpr{X: texpr} // *T -> (*T) so that (*T)(x) is valid
}
rhs[i] = &ast.CallExpr{
Fun: texpr,
Args: []ast.Expr{expr},
}
}
}
logf("substrategy: convert assignment")
return []ast.Stmt{&ast.AssignStmt{
Lhs: lhs,
Tok: callerStmt.Tok,
Rhs: rhs,
}}, true
}
// tailCallSafeReturn reports whether the callee's return statements may be safely
// used to return from the function enclosing the caller (which must exist).
func tailCallSafeReturn(caller *Caller, calleeSymbol *types.Func, callee *gobCallee) bool {
// It is safe if all callee returns involve only trivial conversions.
if callee.TrivialReturns == callee.TotalReturns {
if !hasNonTrivialReturn(callee.Returns) {
return true
}
@ -2683,3 +3071,16 @@ loop:
calleeResults := calleeSymbol.Type().(*types.Signature).Results()
return types.Identical(callerResults, calleeResults)
}
// hasNonTrivialReturn reports whether any of the returns involve a nontrivial
// implicit conversion of a result expression.
func hasNonTrivialReturn(returnInfo [][]returnOperandFlags) bool {
for _, resultInfo := range returnInfo {
for _, r := range resultInfo {
if r&nonTrivialResult != 0 {
return true
}
}
}
return false
}

Просмотреть файл

@ -649,26 +649,26 @@ func TestSubstitution(t *testing.T) {
func TestTailCallStrategy(t *testing.T) {
runTests(t, []testcase{
{
"Tail call.",
"simple",
`func f() int { return 1 }`,
`func _() int { return f() }`,
`func _() int { return 1 }`,
},
{
"Void tail call.",
"void",
`func f() { println() }`,
`func _() { f() }`,
`func _() { println() }`,
},
{
"Void tail call with defer.", // => literalized
"void with defer", // => literalized
`func f() { defer f(); println() }`,
`func _() { f() }`,
`func _() { func() { defer f(); println() }() }`,
},
// Tests for issue #63336:
{
"Tail call with non-trivial return conversion (caller.sig = callee.sig).",
"non-trivial return conversion (caller.sig = callee.sig)",
`func f() error { if true { return nil } else { return e } }; var e struct{error}`,
`func _() error { return f() }`,
`func _() error {
@ -680,7 +680,7 @@ func TestTailCallStrategy(t *testing.T) {
}`,
},
{
"Tail call with non-trivial return conversion (caller.sig != callee.sig).",
"non-trivial return conversion (caller.sig != callee.sig)",
`func f() error { return E{} }; type E struct{error}`,
`func _() any { return f() }`,
`func _() any { return error(E{}) }`,
@ -724,11 +724,169 @@ func TestSpreadCalls(t *testing.T) {
`func _() (int, error) { return f() }`,
`func _() (int, error) { return 0, nil }`,
},
})
}
func TestAssignmentCallStrategy(t *testing.T) {
runTests(t, []testcase{
{
"Implicit return conversions defeat reduction of spread returns, for now.",
"splice: basic",
`func f(x int) (int, int) { return x, 2 }`,
`func _() { x, y := f(1); _, _ = x, y }`,
`func _() { x, y := 1, 2; _, _ = x, y }`,
},
{
"spread: basic",
`func f(x int) (any, any) { return g() }; func g() (error, error) { return nil, nil }`,
`func _() {
var x any
x, y := f(0)
_, _ = x, y
}`,
`func _() {
var x any
var y any
x, y = g()
_, _ = x, y
}`,
},
{
"spread: free var conflict",
`func f(x int) (any, any) { return g(x) }; func g(x int) (int, int) { return x, x }`,
`func _() {
y := 2
{
var x any
x, y := f(y)
_, _ = x, y
}
}`,
`func _() {
y := 2
{
var x any
x, y := func() (any, any) { return g(y) }()
_, _ = x, y
}
}`,
},
{
"convert: basic",
`func f(x int) (int32, int8) { return 1, 2 }`,
`func _() {
var x int32
x, y := f(0)
_, _ = x, y
}`,
`func _() {
var x int32
x, y := 1, int8(2)
_, _ = x, y
}`,
},
{
"convert: rune and byte",
`func f(x int) (rune, byte) { return 0, 0 }`,
`func _() {
x, y := f(0)
_, _ = x, y
}`,
`func _() {
x, y := rune(0), byte(0)
_, _ = x, y
}`,
},
{
"convert: interface conversions",
`func f(x int) (_, _ error) { return nil, nil }`,
`func _() { _, _ = f(0) }`,
`func _() { _, _ = func() (_, _ error) { return nil, nil }() }`,
`func _() {
x, y := f(0)
_, _ = x, y
}`,
`func _() {
x, y := error(nil), error(nil)
_, _ = x, y
}`,
},
{
"convert: implicit nil conversions",
`func f(x int) (_, _ error) { return nil, nil }`,
`func _() { x, y := f(0); _, _ = x, y }`,
`func _() { x, y := error(nil), error(nil); _, _ = x, y }`,
},
{
"convert: pruning nil assignments left",
`func f(x int) (_, _ error) { return nil, nil }`,
`func _() { _, y := f(0); _ = y }`,
`func _() { y := error(nil); _ = y }`,
},
{
"convert: pruning nil assignments right",
`func f(x int) (_, _ error) { return nil, nil }`,
`func _() { x, _ := f(0); _ = x }`,
`func _() { x := error(nil); _ = x }`,
},
{
"convert: partial assign",
`func f(x int) (_, _ error) { return nil, nil }`,
`func _() {
var x error
x, y := f(0)
_, _ = x, y
}`,
`func _() {
var x error
x, y := nil, error(nil)
_, _ = x, y
}`,
},
{
"convert: single assignment left",
`func f() int { return 0 }`,
`func _() {
x, y := f(), "hello"
_, _ = x, y
}`,
`func _() {
x, y := 0, "hello"
_, _ = x, y
}`,
},
{
"convert: single assignment left with conversion",
`func f() int32 { return 0 }`,
`func _() {
x, y := f(), "hello"
_, _ = x, y
}`,
`func _() {
x, y := int32(0), "hello"
_, _ = x, y
}`,
},
{
"convert: single assignment right",
`func f() int32 { return 0 }`,
`func _() {
x, y := "hello", f()
_, _ = x, y
}`,
`func _() {
x, y := "hello", int32(0)
_, _ = x, y
}`,
},
{
"convert: single assignment middle",
`func f() int32 { return 0 }`,
`func _() {
x, y, z := "hello", f(), 1.56
_, _, _ = x, y, z
}`,
`func _() {
x, y, z := "hello", int32(0), 1.56
_, _, _ = x, y, z
}`,
},
})
}

Просмотреть файл

@ -8,6 +8,7 @@ package inline
import (
"go/ast"
"go/constant"
"go/token"
"go/types"
"reflect"
@ -64,8 +65,37 @@ func within(pos token.Pos, n ast.Node) bool {
// The reason for this check is that converting from A to B to C may
// yield a different result than converting A directly to C: consider
// 0 to int32 to any.
func trivialConversion(val types.Type, obj *types.Var) bool {
return types.Identical(types.Default(val), obj.Type())
//
// trivialConversion under-approximates trivial conversions, as unfortunately
// go/types does not record the type of an expression *before* it is implicitly
// converted, and therefore it cannot distinguish typed constant constant
// expressions from untyped constant expressions. For example, in the
// expression `c + 2`, where c is a uint32 constant, trivialConversion does not
// detect that the default type of this express is actually uint32, not untyped
// int.
//
// We could, of course, do better here by reverse engineering some of go/types'
// constant handling. That may or may not be worthwhile..
func trivialConversion(fromValue constant.Value, from, to types.Type) bool {
if fromValue != nil {
var defaultType types.Type
switch fromValue.Kind() {
case constant.Bool:
defaultType = types.Typ[types.Bool]
case constant.String:
defaultType = types.Typ[types.String]
case constant.Int:
defaultType = types.Typ[types.Int]
case constant.Float:
defaultType = types.Typ[types.Float64]
case constant.Complex:
defaultType = types.Typ[types.Complex128]
default:
return false
}
return types.Identical(defaultType, to)
}
return types.Identical(from, to)
}
func checkInfoFields(info *types.Info) {