Skip to content

Commit

Permalink
internal/refactor/inline: elide redundant braces
Browse files Browse the repository at this point in the history
When replacing a CallExpr beneath an ExprStmt,
a reduction strategy may return a BlockStmt containing
zero or more statements.

This change eliminates the braces for the block when
it is safe to do so, in other words when these
three conditions are met:

(a) the parent of the ExprStmt is an unrestricted
    statement context e.g. a block or the body
    of a case of a switch or select, but not, say
    "if f(); cond {".
(b) there are no forward gotos in the caller that
    may jump across a declaration. (Currently
    we check for any control labels at all in the
    caller.)
(c) there are no conflicts between names declared
    in the callee block and in the caller block.

Plus tests.

Also, a fix and test for a latent bug allowing
reduction in a restricted "if stmt; expr" context.

Fixes golang/go#63259

Change-Id: I558c75d8306dfd0679768cb4b3dbf05f14b23c39
Reviewed-on: https://go-review.googlesource.com/c/tools/+/532099
LUCI-TryBot-Result: Go LUCI <[email protected]>
Auto-Submit: Alan Donovan <[email protected]>
Reviewed-by: Robert Findley <[email protected]>
  • Loading branch information
adonovan authored and gopherbot committed Oct 3, 2023
1 parent ca34416 commit 586b21a
Show file tree
Hide file tree
Showing 7 changed files with 322 additions and 126 deletions.
176 changes: 158 additions & 18 deletions internal/refactor/inline/inline.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,40 @@ func Inline(logf func(string, ...any), caller *Caller, callee *Callee) ([]byte,
res.new = &ast.ParenExpr{X: res.new.(ast.Expr)}
}

// Some reduction strategies return a new block holding the
// callee's statements. The block's braces may be elided when
// there is no conflict between names declared in the block
// with those declared by the parent block, and no risk of
// a caller's goto jumping forward across a declaration.
//
// This elision is only safe when the ExprStmt is beneath a
// BlockStmt, CaseClause.Body, or CommClause.Body;
// (see "statement theory").
elideBraces := false
if newBlock, ok := res.new.(*ast.BlockStmt); ok {
parent := caller.path[nodeIndex(caller.path, res.old)+1]
var body []ast.Stmt
switch parent := parent.(type) {
case *ast.BlockStmt:
body = parent.List
case *ast.CommClause:
body = parent.Body
case *ast.CaseClause:
body = parent.Body
}
if body != nil {
if len(callerLabels(caller.path)) > 0 {
// TODO(adonovan): be more precise and reject
// only forward gotos across the inlined block.
logf("keeping block braces: caller uses control labels")
} else if intersects(declares(newBlock.List), declares(body)) {
logf("keeping block braces: avoids name conflict")
} else {
elideBraces = true
}
}
}

// Don't call replaceNode(caller.File, res.old, res.new)
// as it mutates the caller's syntax tree.
// Instead, splice the file, replacing the extent of the "old"
Expand All @@ -124,9 +158,16 @@ func Inline(logf func(string, ...any), caller *Caller, callee *Callee) ([]byte,
// Precise comment handling would make this a
// non-issue. Formatting wouldn't really need a
// FileSet at all.
mark := out.Len()
if err := format.Node(&out, caller.Fset, res.new); err != nil {
return nil, err
}
if elideBraces {
// Overwrite unnecessary {...} braces with spaces.
// TODO(adonovan): less hacky solution.
out.Bytes()[mark] = ' '
out.Bytes()[out.Len()-1] = ' '
}
out.Write(caller.Content[end:])
const mode = parser.ParseComments | parser.SkipObjectResolution | parser.AllErrors
f, err = parser.ParseFile(caller.Fset, "callee.go", &out, mode)
Expand Down Expand Up @@ -630,7 +671,7 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu
logf("strategy: reduce call to empty body")

// Evaluate the arguments for effects and delete the call entirely.
stmt := callStmt(caller.path) // cannot fail
stmt := callStmt(caller.path, false) // cannot fail
res.old = stmt
if nargs := len(remainingArgs); nargs > 0 {
// Emit "_, _ = args" to discard results.
Expand Down Expand Up @@ -862,9 +903,10 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu
// - there is no label conflict between caller and callee
// - all parameters and result vars can be eliminated
// or replaced by a binding decl,
// - caller ExprStmt is in unrestricted statement context.
//
// If there is only a single statement, the braces are omitted.
if stmt := callStmt(caller.path); stmt != nil &&
if stmt := callStmt(caller.path, true); stmt != nil &&
(!needBindingDecl || bindingDeclStmt != nil) &&
!callee.HasDefer &&
!hasLabelConflict(caller.path, callee.Labels) &&
Expand All @@ -876,7 +918,7 @@ func inline(logf func(string, ...any), caller *Caller, callee *gobCallee) (*resu
if needBindingDecl {
body.List = prepend(bindingDeclStmt, body.List...)
}
if len(body.List) == 1 {
if len(body.List) == 1 { // FIXME do this opt later
repl = body.List[0] // singleton: omit braces
}
res.old = stmt
Expand Down Expand Up @@ -2018,30 +2060,41 @@ func callContext(callPath []ast.Node) ast.Node {
// enclosing the call (specified as a PathEnclosingInterval)
// intersects with the set of callee labels.
func hasLabelConflict(callPath []ast.Node, calleeLabels []string) bool {
labels := callerLabels(callPath)
for _, label := range calleeLabels {
if labels[label] {
return true // conflict
}
}
return false
}

// callerLabels returns the set of control labels in the function (if
// any) enclosing the call (specified as a PathEnclosingInterval).
func callerLabels(callPath []ast.Node) map[string]bool {
var callerBody *ast.BlockStmt
switch f := callerFunc(callPath).(type) {
case *ast.FuncDecl:
callerBody = f.Body
case *ast.FuncLit:
callerBody = f.Body
}
conflict := false
var labels map[string]bool
if callerBody != nil {
ast.Inspect(callerBody, func(n ast.Node) bool {
switch n := n.(type) {
case *ast.FuncLit:
return false // prune traversal
case *ast.LabeledStmt:
for _, label := range calleeLabels {
if label == n.Label.Name {
conflict = true
}
if labels == nil {
labels = make(map[string]bool)
}
labels[n.Label.Name] = true
}
return true
})
}
return conflict
return labels
}

// callerFunc returns the innermost Func{Decl,Lit} node enclosing the
Expand All @@ -2059,11 +2112,64 @@ func callerFunc(callPath []ast.Node) ast.Node {
// callStmt reports whether the function call (specified
// as a PathEnclosingInterval) appears within an ExprStmt,
// and returns it if so.
func callStmt(callPath []ast.Node) *ast.ExprStmt {
stmt, _ := callContext(callPath).(*ast.ExprStmt)
//
// If unrestricted, callStmt returns nil if the ExprStmt f() appears
// in a restricted context (such as "if f(); cond {") where it cannot
// be replaced by an arbitrary statement. (See "statement theory".)
func callStmt(callPath []ast.Node, unrestricted bool) *ast.ExprStmt {
stmt, ok := callContext(callPath).(*ast.ExprStmt)
if ok && unrestricted {
switch callPath[nodeIndex(callPath, stmt)+1].(type) {
case *ast.LabeledStmt,
*ast.BlockStmt,
*ast.CaseClause,
*ast.CommClause:
// unrestricted
default:
// TODO(adonovan): handle restricted
// XYZStmt.Init contexts (but not ForStmt.Post)
// by creating a block around the if/for/switch:
// "if f(); cond {" -> "{ stmts; if cond {"

return nil // restricted
}
}
return stmt
}

// Statement theory
//
// These are all the places a statement may appear in the AST:
//
// LabeledStmt.Stmt Stmt -- any
// BlockStmt.List []Stmt -- any (but see switch/select)
// IfStmt.Init Stmt? -- simple
// IfStmt.Body BlockStmt
// IfStmt.Else Stmt? -- IfStmt or BlockStmt
// CaseClause.Body []Stmt -- any
// SwitchStmt.Init Stmt? -- simple
// SwitchStmt.Body BlockStmt -- CaseClauses only
// TypeSwitchStmt.Init Stmt? -- simple
// TypeSwitchStmt.Assign Stmt -- AssignStmt(TypeAssertExpr) or ExprStmt(TypeAssertExpr)
// TypeSwitchStmt.Body BlockStmt -- CaseClauses only
// CommClause.Comm Stmt? -- SendStmt or ExprStmt(UnaryExpr) or AssignStmt(UnaryExpr)
// CommClause.Body []Stmt -- any
// SelectStmt.Body BlockStmt -- CommClauses only
// ForStmt.Init Stmt? -- simple
// ForStmt.Post Stmt? -- simple
// ForStmt.Body BlockStmt
// RangeStmt.Body BlockStmt
//
// simple = AssignStmt | SendStmt | IncDecStmt | ExprStmt.
//
// A BlockStmt cannot replace an ExprStmt in
// {If,Switch,TypeSwitch}Stmt.Init or ForStmt.Post.
// That is allowed only within:
// LabeledStmt.Stmt Stmt
// BlockStmt.List []Stmt
// CaseClause.Body []Stmt
// CommClause.Body []Stmt

// replaceNode performs a destructive update of the tree rooted at
// root, replacing each occurrence of "from" with "to". If to is nil and
// the element is within a slice, the slice element is removed.
Expand Down Expand Up @@ -2372,13 +2478,7 @@ func consistentOffsets(caller *Caller) bool {
// ancestor of the CallExpr identified by its PathEnclosingInterval).
func needsParens(callPath []ast.Node, old, new ast.Node) bool {
// Find enclosing old node and its parent.
// TODO(adonovan): Use index[ast.Node]() in go1.20.
i := -1
for i = range callPath {
if callPath[i] == old {
break
}
}
i := nodeIndex(callPath, old)
if i == -1 {
panic("not found")
}
Expand Down Expand Up @@ -2439,3 +2539,43 @@ func needsParens(callPath []ast.Node, old, new ast.Node) bool {
}
return false
}

func nodeIndex(nodes []ast.Node, n ast.Node) int {
// TODO(adonovan): Use index[ast.Node]() in go1.20.
for i, node := range nodes {
if node == n {
return i
}
}
return -1
}

// declares returns the set of lexical names declared by a
// sequence of statements from the same block, excluding sub-blocks.
// (Lexical names do not include control labels.)
func declares(stmts []ast.Stmt) map[string]bool {
names := make(map[string]bool)
for _, stmt := range stmts {
switch stmt := stmt.(type) {
case *ast.DeclStmt:
for _, spec := range stmt.Decl.(*ast.GenDecl).Specs {
switch spec := spec.(type) {
case *ast.ValueSpec:
for _, id := range spec.Names {
names[id.Name] = true
}
case *ast.TypeSpec:
names[spec.Name.Name] = true
}
}

case *ast.AssignStmt:
if stmt.Tok == token.DEFINE {
for _, lhs := range stmt.Lhs {
names[lhs.(*ast.Ident).Name] = true
}
}
}
}
return names
}
Loading

0 comments on commit 586b21a

Please sign in to comment.