diff --git a/internal/refactor/inline/inline.go b/internal/refactor/inline/inline.go index ef2ce3aae87..4fe918b7c1d 100644 --- a/internal/refactor/inline/inline.go +++ b/internal/refactor/inline/inline.go @@ -99,6 +99,40 @@ func Inline(logf func(string, ...any), caller *Caller, callee *Callee) ([]byte, res.new = &ast.ParenExpr{X: res.new.(ast.Expr)} } + // Some reduction strategies return a new block holding the + // callee's statements. The block's braces may be elided when + // there is no conflict between names declared in the block + // with those declared by the parent block, and no risk of + // a caller's goto jumping forward across a declaration. + // + // This elision is only safe when the ExprStmt is beneath a + // BlockStmt, CaseClause.Body, or CommClause.Body; + // (see "statement theory"). + elideBraces := false + if newBlock, ok := res.new.(*ast.BlockStmt); ok { + parent := caller.path[nodeIndex(caller.path, res.old)+1] + var body []ast.Stmt + switch parent := parent.(type) { + case *ast.BlockStmt: + body = parent.List + case *ast.CommClause: + body = parent.Body + case *ast.CaseClause: + body = parent.Body + } + if body != nil { + if len(callerLabels(caller.path)) > 0 { + // TODO(adonovan): be more precise and reject + // only forward gotos across the inlined block. + logf("keeping block braces: caller uses control labels") + } else if intersects(declares(newBlock.List), declares(body)) { + logf("keeping block braces: avoids name conflict") + } else { + elideBraces = true + } + } + } + // Don't call replaceNode(caller.File, res.old, res.new) // as it mutates the caller's syntax tree. // Instead, splice the file, replacing the extent of the "old" @@ -124,9 +158,16 @@ func Inline(logf func(string, ...any), caller *Caller, callee *Callee) ([]byte, // Precise comment handling would make this a // non-issue. Formatting wouldn't really need a // FileSet at all. + mark := out.Len() if err := format.Node(&out, caller.Fset, res.new); err != nil { return nil, err } + if elideBraces { + // Overwrite unnecessary {...} braces with spaces. + // TODO(adonovan): less hacky solution. + out.Bytes()[mark] = ' ' + out.Bytes()[out.Len()-1] = ' ' + } out.Write(caller.Content[end:]) const mode = parser.ParseComments | parser.SkipObjectResolution | parser.AllErrors f, err = parser.ParseFile(caller.Fset, "callee.go", &out, mode) @@ -630,7 +671,7 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu logf("strategy: reduce call to empty body") // Evaluate the arguments for effects and delete the call entirely. - stmt := callStmt(caller.path) // cannot fail + stmt := callStmt(caller.path, false) // cannot fail res.old = stmt if nargs := len(remainingArgs); nargs > 0 { // Emit "_, _ = args" to discard results. @@ -862,9 +903,10 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu // - there is no label conflict between caller and callee // - all parameters and result vars can be eliminated // or replaced by a binding decl, + // - caller ExprStmt is in unrestricted statement context. // // If there is only a single statement, the braces are omitted. - if stmt := callStmt(caller.path); stmt != nil && + if stmt := callStmt(caller.path, true); stmt != nil && (!needBindingDecl || bindingDeclStmt != nil) && !callee.HasDefer && !hasLabelConflict(caller.path, callee.Labels) && @@ -876,7 +918,7 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu if needBindingDecl { body.List = prepend(bindingDeclStmt, body.List...) } - if len(body.List) == 1 { + if len(body.List) == 1 { // FIXME do this opt later repl = body.List[0] // singleton: omit braces } res.old = stmt @@ -2018,6 +2060,18 @@ func callContext(callPath []ast.Node) ast.Node { // enclosing the call (specified as a PathEnclosingInterval) // intersects with the set of callee labels. func hasLabelConflict(callPath []ast.Node, calleeLabels []string) bool { + labels := callerLabels(callPath) + for _, label := range calleeLabels { + if labels[label] { + return true // conflict + } + } + return false +} + +// callerLabels returns the set of control labels in the function (if +// any) enclosing the call (specified as a PathEnclosingInterval). +func callerLabels(callPath []ast.Node) map[string]bool { var callerBody *ast.BlockStmt switch f := callerFunc(callPath).(type) { case *ast.FuncDecl: @@ -2025,23 +2079,22 @@ func hasLabelConflict(callPath []ast.Node, calleeLabels []string) bool { case *ast.FuncLit: callerBody = f.Body } - conflict := false + var labels map[string]bool if callerBody != nil { ast.Inspect(callerBody, func(n ast.Node) bool { switch n := n.(type) { case *ast.FuncLit: return false // prune traversal case *ast.LabeledStmt: - for _, label := range calleeLabels { - if label == n.Label.Name { - conflict = true - } + if labels == nil { + labels = make(map[string]bool) } + labels[n.Label.Name] = true } return true }) } - return conflict + return labels } // callerFunc returns the innermost Func{Decl,Lit} node enclosing the @@ -2059,11 +2112,64 @@ func callerFunc(callPath []ast.Node) ast.Node { // callStmt reports whether the function call (specified // as a PathEnclosingInterval) appears within an ExprStmt, // and returns it if so. -func callStmt(callPath []ast.Node) *ast.ExprStmt { - stmt, _ := callContext(callPath).(*ast.ExprStmt) +// +// If unrestricted, callStmt returns nil if the ExprStmt f() appears +// in a restricted context (such as "if f(); cond {") where it cannot +// be replaced by an arbitrary statement. (See "statement theory".) +func callStmt(callPath []ast.Node, unrestricted bool) *ast.ExprStmt { + stmt, ok := callContext(callPath).(*ast.ExprStmt) + if ok && unrestricted { + switch callPath[nodeIndex(callPath, stmt)+1].(type) { + case *ast.LabeledStmt, + *ast.BlockStmt, + *ast.CaseClause, + *ast.CommClause: + // unrestricted + default: + // TODO(adonovan): handle restricted + // XYZStmt.Init contexts (but not ForStmt.Post) + // by creating a block around the if/for/switch: + // "if f(); cond {" -> "{ stmts; if cond {" + + return nil // restricted + } + } return stmt } +// Statement theory +// +// These are all the places a statement may appear in the AST: +// +// LabeledStmt.Stmt Stmt -- any +// BlockStmt.List []Stmt -- any (but see switch/select) +// IfStmt.Init Stmt? -- simple +// IfStmt.Body BlockStmt +// IfStmt.Else Stmt? -- IfStmt or BlockStmt +// CaseClause.Body []Stmt -- any +// SwitchStmt.Init Stmt? -- simple +// SwitchStmt.Body BlockStmt -- CaseClauses only +// TypeSwitchStmt.Init Stmt? -- simple +// TypeSwitchStmt.Assign Stmt -- AssignStmt(TypeAssertExpr) or ExprStmt(TypeAssertExpr) +// TypeSwitchStmt.Body BlockStmt -- CaseClauses only +// CommClause.Comm Stmt? -- SendStmt or ExprStmt(UnaryExpr) or AssignStmt(UnaryExpr) +// CommClause.Body []Stmt -- any +// SelectStmt.Body BlockStmt -- CommClauses only +// ForStmt.Init Stmt? -- simple +// ForStmt.Post Stmt? -- simple +// ForStmt.Body BlockStmt +// RangeStmt.Body BlockStmt +// +// simple = AssignStmt | SendStmt | IncDecStmt | ExprStmt. +// +// A BlockStmt cannot replace an ExprStmt in +// {If,Switch,TypeSwitch}Stmt.Init or ForStmt.Post. +// That is allowed only within: +// LabeledStmt.Stmt Stmt +// BlockStmt.List []Stmt +// CaseClause.Body []Stmt +// CommClause.Body []Stmt + // replaceNode performs a destructive update of the tree rooted at // root, replacing each occurrence of "from" with "to". If to is nil and // the element is within a slice, the slice element is removed. @@ -2372,13 +2478,7 @@ func consistentOffsets(caller *Caller) bool { // ancestor of the CallExpr identified by its PathEnclosingInterval). func needsParens(callPath []ast.Node, old, new ast.Node) bool { // Find enclosing old node and its parent. - // TODO(adonovan): Use index[ast.Node]() in go1.20. - i := -1 - for i = range callPath { - if callPath[i] == old { - break - } - } + i := nodeIndex(callPath, old) if i == -1 { panic("not found") } @@ -2439,3 +2539,43 @@ func needsParens(callPath []ast.Node, old, new ast.Node) bool { } return false } + +func nodeIndex(nodes []ast.Node, n ast.Node) int { + // TODO(adonovan): Use index[ast.Node]() in go1.20. + for i, node := range nodes { + if node == n { + return i + } + } + return -1 +} + +// declares returns the set of lexical names declared by a +// sequence of statements from the same block, excluding sub-blocks. +// (Lexical names do not include control labels.) +func declares(stmts []ast.Stmt) map[string]bool { + names := make(map[string]bool) + for _, stmt := range stmts { + switch stmt := stmt.(type) { + case *ast.DeclStmt: + for _, spec := range stmt.Decl.(*ast.GenDecl).Specs { + switch spec := spec.(type) { + case *ast.ValueSpec: + for _, id := range spec.Names { + names[id.Name] = true + } + case *ast.TypeSpec: + names[spec.Name.Name] = true + } + } + + case *ast.AssignStmt: + if stmt.Tok == token.DEFINE { + for _, lhs := range stmt.Lhs { + names[lhs.(*ast.Ident).Name] = true + } + } + } + } + return names +} diff --git a/internal/refactor/inline/inline_test.go b/internal/refactor/inline/inline_test.go index b1e06aef5b1..e47c07625ed 100644 --- a/internal/refactor/inline/inline_test.go +++ b/internal/refactor/inline/inline_test.go @@ -378,10 +378,91 @@ func TestBasics(t *testing.T) { `func f(s string, i int) { print(s, s, i, i) }`, `func _() { f("hi", 0) }`, `func _() { + var s string = "hi" + print(s, s, 0, 0) +}`, + }, + }) +} + +func TestExprStmtReduction(t *testing.T) { + runTests(t, []testcase{ + { + "A call in an unrestricted ExprStmt may be replaced by the body stmts.", + `func f() { var _ = len("") }`, + `func _() { f() }`, + `func _() { var _ = len("") }`, + }, + { + "ExprStmts in the body of a switch case are unrestricted.", + `func f() { x := 1; print(x) }`, + `func _() { switch { case true: f() } }`, + `func _() { + switch { + case true: + x := 1 + print(x) + } +}`, + }, + { + "ExprStmts in the body of a select case are unrestricted.", + `func f() { x := 1; print(x) }`, + `func _() { select { default: f() } }`, + `func _() { + select { + default: + x := 1 + print(x) + } +}`, + }, + { + "Some ExprStmt contexts are restricted to simple statements.", + `func f() { var _ = len("") }`, + `func _(cond bool) { if f(); cond {} }`, + `func _(cond bool) { + if func() { var _ = len("") }(); cond { + } +}`, + }, + { + "Braces must be preserved to avoid a name conflict (decl before).", + `func f() { x := 1; print(x) }`, + `func _() { x := 2; print(x); f() }`, + `func _() { + x := 2 + print(x) + { + x := 1 + print(x) + } +}`, + }, + { + "Braces must be preserved to avoid a name conflict (decl after).", + `func f() { x := 1; print(x) }`, + `func _() { f(); x := 2; print(x) }`, + `func _() { + { + x := 1 + print(x) + } + x := 2 + print(x) +}`, + }, + { + "Braces must be preserved to avoid a forward jump across a decl.", + `func f() { x := 1; print(x) }`, + `func _() { goto label; f(); label: }`, + `func _() { + goto label { - var s string = "hi" - print(s, s, 0, 0) + x := 1 + print(x) } +label: }`, }, }) @@ -576,10 +657,8 @@ func TestParameterBindingDecl(t *testing.T) { `func f(x int) { x++ }`, `func _() { f(1) }`, `func _() { - { - var x int = 1 - x++ - } + var x int = 1 + x++ }`, }, { @@ -587,10 +666,8 @@ func TestParameterBindingDecl(t *testing.T) { `func f(w, x, y any, z int) { println(w, y, z) }; func g(int) int`, `func _() { f(g(0), g(1), g(2), g(3)) }`, `func _() { - { - var w, _ any = g(0), g(1) - println(w, any(g(2)), g(3)) - } + var w, _ any = g(0), g(1) + println(w, any(g(2)), g(3)) }`, }, { @@ -605,10 +682,8 @@ func TestParameterBindingDecl(t *testing.T) { `func f(x int) int { return <-h(g(2), x) }; func g(int) int; func h(int, int) chan int`, `func _() { f(g(1)) }`, `func _() { - { - var x int = g(1) - <-h(g(2), x) - } + var x int = g(1) + <-h(g(2), x) }`, }, { @@ -675,10 +750,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(a, b, c int) { print(a, c, b) }; func g(int) int`, `func _() { f(g(1), g(2), g(3)) }`, `func _() { - { - var a, b int = g(1), g(2) - print(a, g(3), b) - } + var a, b int = g(1), g(2) + print(a, g(3), b) }`, }, { @@ -698,10 +771,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(a, b, c, d int) { print(a, c, b, d) }; func g(int) int; var x, y int`, `func _() { f(g(1), g(2), y, g(3)) }`, `func _() { - { - var a, b int = g(1), g(2) - print(a, y, b, g(3)) - } + var a, b int = g(1), g(2) + print(a, y, b, g(3)) }`, }, { @@ -709,10 +780,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(a, b, c, d int) { print(a, c, b, d) }; func g(int) int; var x, y int`, `func _() { f(g(1), y, g(2), g(3)) }`, `func _() { - { - var a, b int = g(1), y - print(a, g(2), b, g(3)) - } + var a, b int = g(1), y + print(a, g(2), b, g(3)) }`, }, { @@ -732,10 +801,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(a, b, c int) { print(a, b, recover().(int), c) }; var x, y, z int`, `func _() { f(x, y, z) }`, `func _() { - { - var c int = z - print(x, y, recover().(int), c) - } + var c int = z + print(x, y, recover().(int), c) }`, }, { @@ -743,10 +810,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(a, b, c int) { print(a, b, recover().(int), c) }; func g(int) int; var x, y, z int`, `func _() { f(x, y, g(0)) }`, `func _() { - { - var a, b, c int = x, y, g(0) - print(a, b, recover().(int), c) - } + var a, b, c int = x, y, g(0) + print(a, b, recover().(int), c) }`, }, { @@ -754,10 +819,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(a, b, c, d, e int) { print(b, a, c, e, d) }; func g(int) int; var x, y int`, `func _() { f(x, g(1), g(2), y, g(3)) }`, `func _() { - { - var a, b, c, d int = x, g(1), g(2), y - print(b, a, c, g(3), d) - } + var a, b, c, d int = x, g(1), g(2), y + print(b, a, c, g(3), d) }`, }, { @@ -788,10 +851,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(x, y int) { _ = &y }; func g(int) int`, `func _() { f(g(1), g(2)) }`, `func _() { - { - var _, y int = g(1), g(2) - _ = &y - } + var _, y int = g(1), g(2) + _ = &y }`, }, { @@ -802,10 +863,8 @@ func TestSubstitutionPreservesArgumentEffectOrder(t *testing.T) { `func f(x, y int) { _ = x }; func g(int) int; var v int`, `func _() { f(v, g(2)) }`, `func _() { - { - var x, _ int = v, g(2) - _ = x - } + var x, _ int = v, g(2) + _ = x }`, }, { @@ -824,10 +883,8 @@ func TestNamedResultVars(t *testing.T) { `func f() (x int) { return g(x) }; func g(int) int`, `func _() { f() }`, `func _() { - { - var x int - g(x) - } + var x int + g(x) }`, }, { @@ -835,13 +892,11 @@ func TestNamedResultVars(t *testing.T) { `func f(y string) (x int) { return x+x+len(y+y) }`, `func _() { f(".") }`, `func _() { - { - var ( - y string = "." - x int - ) - _ = x + x + len(y+y) - } + var ( + y string = "." + x int + ) + _ = x + x + len(y+y) }`, }, @@ -850,13 +905,11 @@ func TestNamedResultVars(t *testing.T) { `func f(y string) (x string) { return x+y+y }`, `func _() { f(".") }`, `func _() { - { - var ( - y string = "." - x string - ) - _ = x + y + y - } + var ( + y string = "." + x string + ) + _ = x + y + y }`, }, { @@ -864,10 +917,8 @@ func TestNamedResultVars(t *testing.T) { `func f() (x int) { return x+x }`, `func _() { f() }`, `func _() { - { - var x int - _ = x + x - } + var x int + _ = x + x }`, }, { @@ -904,10 +955,8 @@ func TestSubstitutionPreservesParameterType(t *testing.T) { `func f(x int16) { y := x; _ = (*int16)(&y) }`, `func _() { f(1) }`, `func _() { - { - y := int16(1) - _ = (*int16)(&y) - } + y := int16(1) + _ = (*int16)(&y) }`, }, { @@ -915,10 +964,8 @@ func TestSubstitutionPreservesParameterType(t *testing.T) { `func f(x T) { y := x; _ = (*T)(&y) }; type T struct{}`, `func _() { f(struct{}{}) }`, `func _() { - { - y := T(struct{}{}) - _ = (*T)(&y) - } + y := T(struct{}{}) + _ = (*T)(&y) }`, }, { @@ -926,10 +973,8 @@ func TestSubstitutionPreservesParameterType(t *testing.T) { `func f(x T) { y := x; _ = (*T)(&y) }; type T = <-chan int; var ch chan int`, `func _() { f(ch) }`, `func _() { - { - y := T(ch) - _ = (*T)(&y) - } + y := T(ch) + _ = (*T)(&y) }`, }, { @@ -937,10 +982,8 @@ func TestSubstitutionPreservesParameterType(t *testing.T) { `func f(x *int) { y := x; _ = (**int)(&y) }`, `func _() { f(nil) }`, `func _() { - { - y := (*int)(nil) - _ = (**int)(&y) - } + y := (*int)(nil) + _ = (**int)(&y) }`, }, { diff --git a/internal/refactor/inline/testdata/import-shadow.txtar b/internal/refactor/inline/testdata/import-shadow.txtar index 5d4f9243c18..4188a52375d 100644 --- a/internal/refactor/inline/testdata/import-shadow.txtar +++ b/internal/refactor/inline/testdata/import-shadow.txtar @@ -87,10 +87,10 @@ import ( var x b.T func A(b int) { - { - b0.One() - b0.Two() - } //@ inline(re"F", fresult) + + b0.One() + b0.Two() + //@ inline(re"F", fresult) } -- d/d.go -- diff --git a/internal/refactor/inline/testdata/method.txtar b/internal/refactor/inline/testdata/method.txtar index fde9f366c12..b141b09d707 100644 --- a/internal/refactor/inline/testdata/method.txtar +++ b/internal/refactor/inline/testdata/method.txtar @@ -104,10 +104,10 @@ func (T) h() int { return 1 } func _() { var ptr *T - { - var _ T = *ptr - _ = 1 - } //@ inline(re"h", h) + + var _ T = *ptr + _ = 1 + //@ inline(re"h", h) } -- a/i.go -- diff --git a/internal/refactor/inline/testdata/multistmt-body.txtar b/internal/refactor/inline/testdata/multistmt-body.txtar index 1dc39c5c3b9..6bd0108e1fe 100644 --- a/internal/refactor/inline/testdata/multistmt-body.txtar +++ b/internal/refactor/inline/testdata/multistmt-body.txtar @@ -54,10 +54,10 @@ package a func _() { a := 1 - { - z := 1 - print(a + 2 + z) - } //@ inline(re"f", out2) + + z := 1 + print(a + 2 + z) + //@ inline(re"f", out2) } -- a/a3.go -- diff --git a/internal/refactor/inline/testdata/tailcall.txtar b/internal/refactor/inline/testdata/tailcall.txtar index 64f5f9735a0..53b6de367dd 100644 --- a/internal/refactor/inline/testdata/tailcall.txtar +++ b/internal/refactor/inline/testdata/tailcall.txtar @@ -36,19 +36,19 @@ start: package a func _() int { - { - total := 0 - start: - for i := 1; i <= 2; i++ { - total += i - if i == 6 { - goto start - } else if i == 7 { - return -1 - } + + total := 0 +start: + for i := 1; i <= 2; i++ { + total += i + if i == 6 { + goto start + } else if i == 7 { + return -1 } - return total - } //@ inline(re"sum", sum) + } + return total + //@ inline(re"sum", sum) } func sum(lo, hi int) int { diff --git a/internal/refactor/inline/util.go b/internal/refactor/inline/util.go index 82be0fb8ac6..6d8d3fac08f 100644 --- a/internal/refactor/inline/util.go +++ b/internal/refactor/inline/util.go @@ -90,3 +90,16 @@ func funcHasTypeParams(decl *ast.FuncDecl) bool { } return false } + +// intersects reports whether the maps' key sets intersect. +func intersects[K comparable, T1, T2 any](x map[K]T1, y map[K]T2) bool { + if len(x) > len(y) { + return intersects(y, x) + } + for k := range x { + if _, ok := y[k]; ok { + return true + } + } + return false +}