From 586b21ae0964abe3f376aeb52d34f06ad27de00a Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 2 Oct 2023 12:32:29 -0400 Subject: [PATCH] internal/refactor/inline: elide redundant braces When replacing a CallExpr beneath an ExprStmt, a reduction strategy may return a BlockStmt containing zero or more statements. This change eliminates the braces for the block when it is safe to do so, in other words when these three conditions are met: (a) the parent of the ExprStmt is an unrestricted statement context e.g. a block or the body of a case of a switch or select, but not, say "if f(); cond {". (b) there are no forward gotos in the caller that may jump across a declaration. (Currently we check for any control labels at all in the caller.) (c) there are no conflicts between names declared in the callee block and in the caller block. Plus tests. Also, a fix and test for a latent bug allowing reduction in a restricted "if stmt; expr" context. Fixes golang/go#63259 Change-Id: I558c75d8306dfd0679768cb4b3dbf05f14b23c39 Reviewed-on: https://go-review.googlesource.com/c/tools/+/532099 LUCI-TryBot-Result: Go LUCI Auto-Submit: Alan Donovan Reviewed-by: Robert Findley --- internal/refactor/inline/inline.go | 176 +++++++++++++-- internal/refactor/inline/inline_test.go | 211 +++++++++++------- .../inline/testdata/import-shadow.txtar | 8 +- .../refactor/inline/testdata/method.txtar | 8 +- .../inline/testdata/multistmt-body.txtar | 8 +- .../refactor/inline/testdata/tailcall.txtar | 24 +- internal/refactor/inline/util.go | 13 ++ 7 files changed, 322 insertions(+), 126 deletions(-) diff --git a/internal/refactor/inline/inline.go b/internal/refactor/inline/inline.go index ef2ce3aae87..4fe918b7c1d 100644 --- a/internal/refactor/inline/inline.go +++ b/internal/refactor/inline/inline.go @@ -99,6 +99,40 @@ func Inline(logf func(string, ...any), caller *Caller, callee *Callee) ([]byte, res.new = &ast.ParenExpr{X: res.new.(ast.Expr)} } + // Some reduction strategies return a new block holding the + // callee's statements. The block's braces may be elided when + // there is no conflict between names declared in the block + // with those declared by the parent block, and no risk of + // a caller's goto jumping forward across a declaration. + // + // This elision is only safe when the ExprStmt is beneath a + // BlockStmt, CaseClause.Body, or CommClause.Body; + // (see "statement theory"). + elideBraces := false + if newBlock, ok := res.new.(*ast.BlockStmt); ok { + parent := caller.path[nodeIndex(caller.path, res.old)+1] + var body []ast.Stmt + switch parent := parent.(type) { + case *ast.BlockStmt: + body = parent.List + case *ast.CommClause: + body = parent.Body + case *ast.CaseClause: + body = parent.Body + } + if body != nil { + if len(callerLabels(caller.path)) > 0 { + // TODO(adonovan): be more precise and reject + // only forward gotos across the inlined block. + logf("keeping block braces: caller uses control labels") + } else if intersects(declares(newBlock.List), declares(body)) { + logf("keeping block braces: avoids name conflict") + } else { + elideBraces = true + } + } + } + // Don't call replaceNode(caller.File, res.old, res.new) // as it mutates the caller's syntax tree. // Instead, splice the file, replacing the extent of the "old" @@ -124,9 +158,16 @@ func Inline(logf func(string, ...any), caller *Caller, callee *Callee) ([]byte, // Precise comment handling would make this a // non-issue. Formatting wouldn't really need a // FileSet at all. + mark := out.Len() if err := format.Node(&out, caller.Fset, res.new); err != nil { return nil, err } + if elideBraces { + // Overwrite unnecessary {...} braces with spaces. + // TODO(adonovan): less hacky solution. + out.Bytes()[mark] = ' ' + out.Bytes()[out.Len()-1] = ' ' + } out.Write(caller.Content[end:]) const mode = parser.ParseComments | parser.SkipObjectResolution | parser.AllErrors f, err = parser.ParseFile(caller.Fset, "callee.go", &out, mode) @@ -630,7 +671,7 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu logf("strategy: reduce call to empty body") // Evaluate the arguments for effects and delete the call entirely. - stmt := callStmt(caller.path) // cannot fail + stmt := callStmt(caller.path, false) // cannot fail res.old = stmt if nargs := len(remainingArgs); nargs > 0 { // Emit "_, _ = args" to discard results. @@ -862,9 +903,10 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu // - there is no label conflict between caller and callee // - all parameters and result vars can be eliminated // or replaced by a binding decl, + // - caller ExprStmt is in unrestricted statement context. // // If there is only a single statement, the braces are omitted. - if stmt := callStmt(caller.path); stmt != nil && + if stmt := callStmt(caller.path, true); stmt != nil && (!needBindingDecl || bindingDeclStmt != nil) && !callee.HasDefer && !hasLabelConflict(caller.path, callee.Labels) && @@ -876,7 +918,7 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu if needBindingDecl { body.List = prepend(bindingDeclStmt, body.List...) } - if len(body.List) == 1 { + if len(body.List) == 1 { // FIXME do this opt later repl = body.List[0] // singleton: omit braces } res.old = stmt @@ -2018,6 +2060,18 @@ func callContext(callPath []ast.Node) ast.Node { // enclosing the call (specified as a PathEnclosingInterval) // intersects with the set of callee labels. func hasLabelConflict(callPath []ast.Node, calleeLabels []string) bool { + labels := callerLabels(callPath) + for _, label := range calleeLabels { + if labels[label] { + return true // conflict + } + } + return false +} + +// callerLabels returns the set of control labels in the function (if +// any) enclosing the call (specified as a PathEnclosingInterval). +func callerLabels(callPath []ast.Node) map[string]bool { var callerBody *ast.BlockStmt switch f := callerFunc(callPath).(type) { case *ast.FuncDecl: @@ -2025,23 +2079,22 @@ func hasLabelConflict(callPath []ast.Node, calleeLabels []string) bool { case *ast.FuncLit: callerBody = f.Body } - conflict := false + var labels map[string]bool if callerBody != nil { ast.Inspect(callerBody, func(n ast.Node) bool { switch n := n.(type) { case *ast.FuncLit: return false // prune traversal case *ast.LabeledStmt: - for _, label := range calleeLabels { - if label == n.Label.Name { - conflict = true - } + if labels == nil { + labels = make(map[string]bool) } + labels[n.Label.Name] = true } return true }) } - return conflict + return labels } // callerFunc returns the innermost Func{Decl,Lit} node enclosing the @@ -2059,11 +2112,64 @@ func callerFunc(callPath []ast.Node) ast.Node { // callStmt reports whether the function call (specified // as a PathEnclosingInterval) appears within an ExprStmt, // and returns it if so. -func callStmt(callPath []ast.Node) *ast.ExprStmt { - stmt, _ := callContext(callPath).(*ast.ExprStmt) +// +// If unrestricted, callStmt returns nil if the ExprStmt f() appears +// in a restricted context (such as "if f(); cond {") where it cannot +// be replaced by an arbitrary statement. (See "statement theory".) +func callStmt(callPath []ast.Node, unrestricted bool) *ast.ExprStmt { + stmt, ok := callContext(callPath).(*ast.ExprStmt) + if ok && unrestricted { + switch callPath[nodeIndex(callPath, stmt)+1].(type) { + case *ast.LabeledStmt, + *ast.BlockStmt, + *ast.CaseClause, + *ast.CommClause: + // unrestricted + default: + // TODO(adonovan): handle restricted + // XYZStmt.Init contexts (but not ForStmt.Post) + // by creating a block around the if/for/switch: + // "if f(); cond {" -> "{ stmts; if cond {" + + return nil // restricted + } + } return stmt } +// Statement theory +// +// These are all the places a statement may appear in the AST: +// +// LabeledStmt.Stmt Stmt -- any +// BlockStmt.List []Stmt -- any (but see switch/select) +// IfStmt.Init Stmt? -- simple +// IfStmt.Body BlockStmt +// IfStmt.Else Stmt? -- IfStmt or BlockStmt +// CaseClause.Body []Stmt -- any +// SwitchStmt.Init Stmt? -- simple +// SwitchStmt.Body BlockStmt -- CaseClauses only +// TypeSwitchStmt.Init Stmt? -- simple +// TypeSwitchStmt.Assign Stmt -- AssignStmt(TypeAssertExpr) or ExprStmt(TypeAssertExpr) +// TypeSwitchStmt.Body BlockStmt -- CaseClauses only +// CommClause.Comm Stmt? -- SendStmt or ExprStmt(UnaryExpr) or AssignStmt(UnaryExpr) +// CommClause.Body []Stmt -- any +// SelectStmt.Body BlockStmt -- CommClauses only +// ForStmt.Init Stmt? -- simple +// ForStmt.Post Stmt? -- simple +// ForStmt.Body BlockStmt +// RangeStmt.Body BlockStmt +// +// simple = AssignStmt | SendStmt | IncDecStmt | ExprStmt. +// +// A BlockStmt cannot replace an ExprStmt in +// {If,Switch,TypeSwitch}Stmt.Init or ForStmt.Post. +// That is allowed only within: +// LabeledStmt.Stmt Stmt +// BlockStmt.List []Stmt +// CaseClause.Body []Stmt +// CommClause.Body []Stmt + // replaceNode performs a destructive update of the tree rooted at // root, replacing each occurrence of "from" with "to". If to is nil and // the element is within a slice, the slice element is removed. @@ -2372,13 +2478,7 @@ func consistentOffsets(caller *Caller) bool { // ancestor of the CallExpr identified by its PathEnclosingInterval). func needsParens(callPath []ast.Node, old, new ast.Node) bool { // Find enclosing old node and its parent. - // TODO(adonovan): Use index[ast.Node]() in go1.20. - i := -1 - for i = range callPath { - if callPath[i] == old { - break - } - } + i := nodeIndex(callPath, old) if i == -1 { panic("not found") } @@ -2439,3 +2539,43 @@ func needsParens(callPath []ast.Node, old, new ast.Node) bool { } return false } + +func nodeIndex(nodes []ast.Node, n ast.Node) int { + // TODO(adonovan): Use index[ast.Node]() in go1.20. + for i, node := range nodes { + if node == n { + return i + } + } + return -1 +} + +// declares returns the set of lexical names declared by a +// sequence of statements from the same block, excluding sub-blocks. +// (Lexical names do not include control labels.) +func declares(stmts []ast.Stmt) map[string]bool { + names := make(map[string]bool) + for _, stmt := range stmts { + switch stmt := stmt.(type) { + case *ast.DeclStmt: + for _, spec := range stmt.Decl.(*ast.GenDecl).Specs { + switch spec := spec.(type) { + case *ast.ValueSpec: + for _, id := range spec.Names { + names[id.Name] = true + } + case *ast.TypeSpec: + names[spec.Name.Name] = true + } + } + + case *ast.AssignStmt: + if stmt.Tok == token.DEFINE { + for _, lhs := range stmt.Lhs { + names[lhs.(*ast.Ident).Name] = true + } + } + } + } + return names +} diff --git a/internal/refactor/inline/inline_test.go b/internal/refactor/inline/inline_test.go index b1e06aef5b1..e47c07625ed 100644 --- a/internal/refactor/inline/inline_test.go +++ b/internal/refactor/inline/inline_test.go @@ -378,10 +378,91 @@ func TestBasics(t *testing.T) { `func f(s string, i int) { print(s, s, i, i) }`, `func _() { f("hi", 0) }`, `func _() { + var s string = "hi" + print(s, s, 0, 0) +}`, + }, + }) +} + +func TestExprStmtReduction(t *testing.T) { + runTests(t, []testcase{ + { + "A call in an unrestricted ExprStmt may be replaced by the body stmts.", + `func f() { var _ = len("") }`, + `func _() { f() }`, + `func _() { var _ = len("") }`, + }, + { + "ExprStmts in the body of a switch case are unrestricted.", + `func f() { x := 1; print(x) }`, + `func _() { switch { case true: f() } }`, + `func _() { + switch { + case true: + x := 1 + print(x) + } +}`, + }, + { + "ExprStmts in the body of a select case are unrestricted.", + `func f() { x := 1; print(x) }`, + `func _() { select { default: f() } }`, + `func _() { + select { + default: + x := 1 + print(x) + } +}`, + }, + { + "Some ExprStmt contexts are restricted to simple statements.", + `func f() { var _ = len("") }`, + `func _(cond bool) { if f(); cond {} }`, + `func _(cond bool) { + if func() { var _ = len("") }(); cond { + } +}`, + }, + { + "Braces must be preserved to avoid a name conflict (decl before).", + `func f() { x := 1; print(x) }`, + `func _() { x := 2; print(x); f() }`, + `func _() { + x := 2 + print(x) + { + x := 1 + print(x) + } +}`, + }, + { + "Braces must be preserved to avoid a name conflict (decl after).", + `func f() { x := 1; print(x) }`, + `func _() { f(); x := 2; print(x) }`, + `func _() { + { + x := 1 + print(x) + } + x := 2 + print(x) +}`, + }, + { + "Braces must be preserved to avoid a forward jump across a decl.", + `func f() { x := 1; print(x) }`, + `func _() { goto label; f(); label: }`, + `func _() { + goto label { - var s string = "hi" - print(s, s, 0, 0) + x := 1 + print(x) } +label: }`, }, }) @@ -576,10 +657,8 @@ func TestParameterBindingDecl(t *testing.T) { `func f(x int) { x++ }`, `func _() { f(1) }`, `func _() { - { - var x int = 1 - x++ - } + var x int = 1 + x++ }`, }, { @@ -587,10 +666,8 @@ func TestParameterBindingDecl(t *testing.T) { `func f(w, x, y any, z int) { println(w, y, z) }; func g(int) int`, `func _() { f(g(0), g(1), g(2), g(3)) }`, `func _() { - { - var w, _ any = g(0), g(1) - println(w, any(g(2)), g(3)) - } + var w, _ any = g(0), g(1) + println(w, any(g(2)), g(3)) }`, }, { @@ -605,10 +682,8 @@ func TestParameterBindingDecl(t *testing.T) { `func f(x int) int { return <-h(g(2), x) }; func g(int) int; func h(int, int) chan int`, `func _() { f(g(1)) }`, `func _() { - { - var x int = g(1) - <-h(g(2), x) - } + var x int = g(1) + <-h(g(2), x) }`, }, { @@ -675,10 +750,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(a, b, c int) { print(a, c, b) }; func g(int) int`, `func _() { f(g(1), g(2), g(3)) }`, `func _() { - { - var a, b int = g(1), g(2) - print(a, g(3), b) - } + var a, b int = g(1), g(2) + print(a, g(3), b) }`, }, { @@ -698,10 +771,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(a, b, c, d int) { print(a, c, b, d) }; func g(int) int; var x, y int`, `func _() { f(g(1), g(2), y, g(3)) }`, `func _() { - { - var a, b int = g(1), g(2) - print(a, y, b, g(3)) - } + var a, b int = g(1), g(2) + print(a, y, b, g(3)) }`, }, { @@ -709,10 +780,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(a, b, c, d int) { print(a, c, b, d) }; func g(int) int; var x, y int`, `func _() { f(g(1), y, g(2), g(3)) }`, `func _() { - { - var a, b int = g(1), y - print(a, g(2), b, g(3)) - } + var a, b int = g(1), y + print(a, g(2), b, g(3)) }`, }, { @@ -732,10 +801,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(a, b, c int) { print(a, b, recover().(int), c) }; var x, y, z int`, `func _() { f(x, y, z) }`, `func _() { - { - var c int = z - print(x, y, recover().(int), c) - } + var c int = z + print(x, y, recover().(int), c) }`, }, { @@ -743,10 +810,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(a, b, c int) { print(a, b, recover().(int), c) }; func g(int) int; var x, y, z int`, `func _() { f(x, y, g(0)) }`, `func _() { - { - var a, b, c int = x, y, g(0) - print(a, b, recover().(int), c) - } + var a, b, c int = x, y, g(0) + print(a, b, recover().(int), c) }`, }, { @@ -754,10 +819,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(a, b, c, d, e int) { print(b, a, c, e, d) }; func g(int) int; var x, y int`, `func _() { f(x, g(1), g(2), y, g(3)) }`, `func _() { - { - var a, b, c, d int = x, g(1), g(2), y - print(b, a, c, g(3), d) - } + var a, b, c, d int = x, g(1), g(2), y + print(b, a, c, g(3), d) }`, }, { @@ -788,10 +851,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(x, y int) { _ = &y }; func g(int) int`, `func _() { f(g(1), g(2)) }`, `func _() { - { - var _, y int = g(1), g(2) - _ = &y - } + var _, y int = g(1), g(2) + _ = &y }`, }, { @@ -802,10 +863,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(x, y int) { _ = x }; func g(int) int; var v int`, `func _() { f(v, g(2)) }`, `func _() { - { - var x, _ int = v, g(2) - _ = x - } + var x, _ int = v, g(2) + _ = x }`, }, { @@ -824,10 +883,8 @@ func TestNamedResultVars(t *testing.T) { `func f() (x int) { return g(x) }; func g(int) int`, `func _() { f() }`, `func _() { - { - var x int - g(x) - } + var x int + g(x) }`, }, { @@ -835,13 +892,11 @@ func TestNamedResultVars(t *testing.T) { `func f(y string) (x int) { return x+x+len(y+y) }`, `func _() { f(".") }`, `func _() { - { - var ( - y string = "." - x int - ) - _ = x + x + len(y+y) - } + var ( + y string = "." + x int + ) + _ = x + x + len(y+y) }`, }, @@ -850,13 +905,11 @@ func TestNamedResultVars(t *testing.T) { `func f(y string) (x string) { return x+y+y }`, `func _() { f(".") }`, `func _() { - { - var ( - y string = "." - x string - ) - _ = x + y + y - } + var ( + y string = "." + x string + ) + _ = x + y + y }`, }, { @@ -864,10 +917,8 @@ func TestNamedResultVars(t *testing.T) { `func f() (x int) { return x+x }`, `func _() { f() }`, `func _() { - { - var x int - _ = x + x - } + var x int + _ = x + x }`, }, { @@ -904,10 +955,8 @@ func TestSubstitutionPreservesParameterType(t *testing.T) { `func f(x int16) { y := x; _ = (*int16)(&y) }`, `func _() { f(1) }`, `func _() { - { - y := int16(1) - _ = (*int16)(&y) - } + y := int16(1) + _ = (*int16)(&y) }`, }, { @@ -915,10 +964,8 @@ func TestSubstitutionPreservesParameterType(t *testing.T) { `func f(x T) { y := x; _ = (*T)(&y) }; type T struct{}`, `func _() { f(struct{}{}) }`, `func _() { - { - y := T(struct{}{}) - _ = (*T)(&y) - } + y := T(struct{}{}) + _ = (*T)(&y) }`, }, { @@ -926,10 +973,8 @@ func TestSubstitutionPreservesParameterType(t *testing.T) { `func f(x T) { y := x; _ = (*T)(&y) }; type T = <-chan int; var ch chan int`, `func _() { f(ch) }`, `func _() { - { - y := T(ch) - _ = (*T)(&y) - } + y := T(ch) + _ = (*T)(&y) }`, }, { @@ -937,10 +982,8 @@ func TestSubstitutionPreservesParameterType(t *testing.T) { `func f(x *int) { y := x; _ = (**int)(&y) }`, `func _() { f(nil) }`, `func _() { - { - y := (*int)(nil) - _ = (**int)(&y) - } + y := (*int)(nil) + _ = (**int)(&y) }`, }, { diff --git a/internal/refactor/inline/testdata/import-shadow.txtar b/internal/refactor/inline/testdata/import-shadow.txtar index 5d4f9243c18..4188a52375d 100644 --- a/internal/refactor/inline/testdata/import-shadow.txtar +++ b/internal/refactor/inline/testdata/import-shadow.txtar @@ -87,10 +87,10 @@ import ( var x b.T func A(b int) { - { - b0.One() - b0.Two() - } //@ inline(re"F", fresult) + + b0.One() + b0.Two() + //@ inline(re"F", fresult) } -- d/d.go -- diff --git a/internal/refactor/inline/testdata/method.txtar b/internal/refactor/inline/testdata/method.txtar index fde9f366c12..b141b09d707 100644 --- a/internal/refactor/inline/testdata/method.txtar +++ b/internal/refactor/inline/testdata/method.txtar @@ -104,10 +104,10 @@ func (T) h() int { return 1 } func _() { var ptr *T - { - var _ T = *ptr - _ = 1 - } //@ inline(re"h", h) + + var _ T = *ptr + _ = 1 + //@ inline(re"h", h) } -- a/i.go -- diff --git a/internal/refactor/inline/testdata/multistmt-body.txtar b/internal/refactor/inline/testdata/multistmt-body.txtar index 1dc39c5c3b9..6bd0108e1fe 100644 --- a/internal/refactor/inline/testdata/multistmt-body.txtar +++ b/internal/refactor/inline/testdata/multistmt-body.txtar @@ -54,10 +54,10 @@ package a func _() { a := 1 - { - z := 1 - print(a + 2 + z) - } //@ inline(re"f", out2) + + z := 1 + print(a + 2 + z) + //@ inline(re"f", out2) } -- a/a3.go -- diff --git a/internal/refactor/inline/testdata/tailcall.txtar b/internal/refactor/inline/testdata/tailcall.txtar index 64f5f9735a0..53b6de367dd 100644 --- a/internal/refactor/inline/testdata/tailcall.txtar +++ b/internal/refactor/inline/testdata/tailcall.txtar @@ -36,19 +36,19 @@ start: package a func _() int { - { - total := 0 - start: - for i := 1; i <= 2; i++ { - total += i - if i == 6 { - goto start - } else if i == 7 { - return -1 - } + + total := 0 +start: + for i := 1; i <= 2; i++ { + total += i + if i == 6 { + goto start + } else if i == 7 { + return -1 } - return total - } //@ inline(re"sum", sum) + } + return total + //@ inline(re"sum", sum) } func sum(lo, hi int) int { diff --git a/internal/refactor/inline/util.go b/internal/refactor/inline/util.go index 82be0fb8ac6..6d8d3fac08f 100644 --- a/internal/refactor/inline/util.go +++ b/internal/refactor/inline/util.go @@ -90,3 +90,16 @@ func funcHasTypeParams(decl *ast.FuncDecl) bool { } return false } + +// intersects reports whether the maps' key sets intersect. +func intersects[K comparable, T1, T2 any](x map[K]T1, y map[K]T2) bool { + if len(x) > len(y) { + return intersects(y, x) + } + for k := range x { + if _, ok := y[k]; ok { + return true + } + } + return false +}