Skip to content

Commit

Permalink
Optimize basic blocks and build CFG
Browse files Browse the repository at this point in the history
- Optimize empty blocks by merging successors
- Insert labels for blocks
- Build instructions by DFS traversal
- Build CFG with cached edge links
- Reset assembler indent for function labels
- Create blocks with scope for IR building
  • Loading branch information
whtoo committed Nov 30, 2023
1 parent 90a9db9 commit 2eb99a9
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 62 deletions.
7 changes: 5 additions & 2 deletions ep20/src/main/java/org/teachfx/antlr4/ep20/Compiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@ public static void main(String[] args) throws IOException {
var irBuilder = new CymbolIRBuilder();

astRoot.accept(irBuilder);
for(var block : irBuilder.prog.blockList) {
var cfg = irBuilder.getCFG(List.of(block));

irBuilder.prog.optimizeBasicBlock();

for(var funBlock : irBuilder.prog.blockList) {
var cfg = irBuilder.getCFG(funBlock);
logger.info("CFG:\n" + cfg.toString());
}

Expand Down
65 changes: 53 additions & 12 deletions ep20/src/main/java/org/teachfx/antlr4/ep20/ir/Prog.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
import org.teachfx.antlr4.ep20.ir.stmt.Label;
import org.teachfx.antlr4.ep20.pass.cfg.LinearIRBlock;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.*;


public class Prog extends IRNode {
Expand All @@ -33,31 +30,75 @@ public <S,E> S accept(IRVisitor<S,E> visitor){
public void addBlock(LinearIRBlock linearIRBlock) {
blockList.add(linearIRBlock);
}
protected TreeSet<LinearIRBlock> needRemovedBlocks = new TreeSet<>();

private void linearInstrsImpl(@NotNull LinearIRBlock linearIRBlock) {
// Add all instr from non-empty block
if (!linearIRBlock.getStmts().isEmpty()) {
instrs.addAll(linearIRBlock.getStmts());
} else {
private void optimizeEmptyBlock(@NotNull LinearIRBlock linearIRBlock) {
// replace empty block within non-empty first successor
if (linearIRBlock.getStmts().isEmpty()){
// Drop empty block
if (linearIRBlock.getSuccessors().isEmpty()) {
needRemovedBlocks.add(linearIRBlock);
return;
}
// Auto-fill next block for jmp/cjmp
var nextBlock = linearIRBlock.getSuccessors().get(0);
for (var ref : linearIRBlock.getJmpRefMap()){
if (ref instanceof JMP jmp) {
jmp.setNext(nextBlock);

} else if (ref instanceof CJMP cjmp) {
cjmp.setElseBlock(nextBlock);
}
}

linearIRBlock.getPredecessors().forEach(prev -> {
prev.removeSuccessor(linearIRBlock);
prev.getSuccessors().add(nextBlock);
});

needRemovedBlocks.add(linearIRBlock);
}

// recursive call
for(var successor : linearIRBlock.getSuccessors()){
linearInstrsImpl(successor);
optimizeEmptyBlock(successor);
}
}

private void insertLabelForBlock(LinearIRBlock startBlock) {
for (var stmt : startBlock.getStmts()) {
if (stmt instanceof Label) {
break;
}

startBlock.insertStmt(new Label(startBlock.getScope(), startBlock.getOrd()),0);
break; // only insert one label for each block which is not func-entry block.
}

for (var successor : startBlock.getSuccessors()) {
insertLabelForBlock(successor);
}
}
protected void buildInstrs(LinearIRBlock block) {
instrs.addAll(block.getStmts());

for (var successor : block.getSuccessors()) {
buildInstrs(successor);
}

}

public void optimizeBasicBlock() {
for(var func : blockList) {
optimizeEmptyBlock(func);
}

for( var emptyBlock : needRemovedBlocks) {
emptyBlock.getPredecessors().forEach(p -> p.removeSuccessor(emptyBlock));
}

for(var func : blockList) {
insertLabelForBlock(func);
}
}

Expand All @@ -67,8 +108,8 @@ public List<IRNode> linearInstrs() {
return truncateInstrList;
}

for(var block : blockList) {
linearInstrsImpl(block);
for(var func : blockList) {
buildInstrs(func);
}

IRNode prev;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ public Label(Scope scope,Integer ord) {
}

public Label(Scope scope) {
this.scope = scope;
this.seq = scope.getLabelSeq();
this.rawLabel = null;
this(scope, scope.getLabelSeq());
}

public void setRawLabel(String rawLabel) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ public enum StmtType {
public abstract <S,E> S accept(IRVisitor<S,E> visitor);

public abstract StmtType getStmtType();

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import org.jetbrains.annotations.NotNull;
import org.teachfx.antlr4.ep20.ir.IRNode;
import org.teachfx.antlr4.ep20.ir.expr.Operand;
import org.teachfx.antlr4.ep20.ir.stmt.CJMP;
import org.teachfx.antlr4.ep20.ir.stmt.FuncEntryLabel;
import org.teachfx.antlr4.ep20.ir.stmt.JMP;
import org.teachfx.antlr4.ep20.ir.stmt.Label;
import org.teachfx.antlr4.ep20.utils.Kind;

Expand All @@ -25,17 +28,17 @@ public class BasicBlock<I extends IRNode> implements Comparable<BasicBlock<I>>,
public Set<Operand> liveOut;
protected Label label;

public BasicBlock(Kind kind, List<Loc<I>> codes, Label label) {
public BasicBlock(Kind kind, List<Loc<I>> codes,Label label,int ord) {
this.codes = codes;
this.label = label;
this.id = label.getSeq();
this.id = ord;
this.kind = kind;
}

@NotNull
@Contract("_ -> new")
public static BasicBlock<IRNode> buildFromLinearBlock(@NotNull LinearIRBlock block) {
return new BasicBlock<IRNode>(block.getKind(), block.getStmts().stream().map(Loc::new).toList(), block.getLabel());
public static BasicBlock<IRNode> buildFromLinearBlock(@NotNull LinearIRBlock block,List<BasicBlock<IRNode>> cachedNodes) {
return new BasicBlock<IRNode>(block.getKind(), block.getStmts().stream().map(Loc::new).toList(),block.getLabel(),block.getOrd());
}

@Override
Expand Down Expand Up @@ -93,6 +96,11 @@ public List<Loc<I>> allSeq() {

public List<Loc<I>> dropLabelSeq() {
if (codes.size() <= 1) return codes;

if (codes.get(0).instr instanceof FuncEntryLabel) {
return codes.subList(0, codes.size());
}

return codes.subList(1, codes.size());
}

Expand Down
12 changes: 5 additions & 7 deletions ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/CFG.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,14 @@ public class CFG<I extends IRNode> implements Iterable<BasicBlock<I>> {
// index: 第几号节点 -> <prev,successors> : <前驱节点的集合,后继节点的集合>
private final List<Pair<Set<Integer>, Set<Integer>>> links;

public CFG(Map<Integer, BasicBlock<I>> nodes, List<Pair<Integer, Integer>> edges) {
public CFG(List<BasicBlock<I>> nodes, List<Pair<Integer, Integer>> edges) {
// Generate init
var lastOrd = Integer.max(nodes.keySet().stream().max(Integer::compareTo).get(), nodes.size()) + 1;

this.nodes = new LinkedList<>(nodes.values());
var maxOrd = nodes.stream().max(BasicBlock::compareTo).map(BasicBlock::getId).get() + 1;
this.nodes = nodes;
this.edges = edges;

links = new ArrayList<>();
for (var i = 0; i < lastOrd; i++) {
for (var i = 0; i < maxOrd; i++) {
links.add(Pair.of(new TreeSet<>(), new TreeSet<>()));
}

Expand Down Expand Up @@ -78,7 +77,7 @@ public String toString() {
var graphRenderBuffer = new StringBuilder("graph TD\n");
AtomicInteger i = new AtomicInteger();

for (var node : nodes.stream().sorted((b1,b2)-> b2.id - b1.id).toList()) {
for (var node : nodes.stream().sorted((b1,b2) -> b2.id - b1.id).toList()) {

graphRenderBuffer.append("subgraph ").append(node.getOrdLabel()).append("\n");
node.dropLabelSeq().stream().map(x -> x.instr.toString()).map(x -> "Q" + (i.getAndIncrement()) + "[\"" + x + ";\"]\n").forEach(graphRenderBuffer::append);
Expand All @@ -87,7 +86,6 @@ public String toString() {
}

for (var edge : edges) {

graphRenderBuffer.append("L").append(edge.getLeft()).append(" --> ").append("L").append(edge.getRight()).append("\n");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,62 @@

import org.apache.commons.lang3.tuple.Pair;
import org.teachfx.antlr4.ep20.ir.IRNode;
import org.teachfx.antlr4.ep20.ir.JMPInstr;
import org.teachfx.antlr4.ep20.ir.stmt.CJMP;
import org.teachfx.antlr4.ep20.ir.stmt.JMP;

import java.util.*;

public class CFGBuilder {
private static final Set<String> cachedEdgeLinks = new HashSet<>();
private final CFG<IRNode> cfg;
private final Map<Integer, BasicBlock<IRNode>> basicBlocks;
private final List<BasicBlock<IRNode>> basicBlocks;
private final List<Pair<Integer, Integer>> edges;

public CFGBuilder(List<LinearIRBlock> blockList) {
basicBlocks = new HashMap<>();
public CFGBuilder(LinearIRBlock startBlock) {
basicBlocks = new ArrayList<>();
edges = new ArrayList<>();
for (var funcLabelBlock : blockList) {
build(funcLabelBlock);
}

var cachedEdgeLink = new HashSet<String>();

build(startBlock,cachedEdgeLink);

cfg = new CFG<>(basicBlocks, edges);
}

private void build(LinearIRBlock block) {
basicBlocks.put(block.getLabel().getSeq(), BasicBlock.buildFromLinearBlock(block));
private void build(LinearIRBlock block,Set<String> cachedEdgeLinks) {
var currentBlock = BasicBlock.buildFromLinearBlock(block,basicBlocks);
basicBlocks.add(currentBlock);
var lastInstr = block.getStmts().get(block.getStmts().size() - 1);
var currentOrd = block.getOrd();

if (lastInstr instanceof JMP jmp) {
var destOrd = jmp.getNext().getOrd();
var key = currentOrd + "-" + destOrd;
if (!cachedEdgeLinks.contains(key)) {
cachedEdgeLinks.add(key);
edges.add(Pair.of(currentOrd, destOrd));
}
} else if (lastInstr instanceof CJMP cjmp) {
var elseOrd = cjmp.getElseBlock().getOrd();
var key = currentOrd + "-" + elseOrd;
if (!cachedEdgeLinks.contains(key)) {
cachedEdgeLinks.add(key);
edges.add(Pair.of(currentOrd, elseOrd));
}
}

for (var successor : block.getSuccessors()){
var key = currentOrd + "-" + successor.getOrd();
if (!cachedEdgeLinks.contains(key)) {
cachedEdgeLinks.add(key);
edges.add(Pair.of(currentOrd, successor.getOrd()));
}
build(successor,cachedEdgeLinks);
}
}


public CFG<IRNode> getCFG() {
return cfg;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,23 @@ public class LinearIRBlock implements Comparable<LinearIRBlock> {
private List<JMPInstr> jmpRefMap = new ArrayList<>();

// Constructor
public LinearIRBlock() {
public LinearIRBlock(Scope scope) {
stmts = new ArrayList<>();
successors = new ArrayList<>();
predecessors = new ArrayList<>();
ord = LABEL_SEQ++;
this.scope = scope;
logger.info(ord);
}

public LinearIRBlock() {
stmts = new ArrayList<>();
successors = new ArrayList<>();
predecessors = new ArrayList<>();
ord = LABEL_SEQ++;
this.scope = null;
logger.info(ord);
}
public static boolean isBasicBlock(Stmt stmt) {
return !(stmt instanceof CJMP) && !(stmt instanceof JMP);
}
Expand All @@ -51,6 +60,10 @@ public void addStmt(IRNode stmt) {
updateKindByLastInstr(stmt);
}

public void insertStmt(IRNode stmt,int idx) {
stmts.add(idx, stmt);
}

private void updateKindByLastInstr(IRNode stmt) {
if (stmt instanceof CJMP) {
kind = Kind.END_BY_CJMP;
Expand Down Expand Up @@ -152,14 +165,10 @@ public Label getLabel() {

var firstInstr = stmts.get(0);
if (firstInstr instanceof Label label) {

return label;
} else {
stmts.add(0, new Label(scope, getOrd()));
firstInstr = stmts.get(0);
}

return (Label) firstInstr;
return null;
}

public Optional<TreeSet<LinearIRBlock>> getJumpEntries() {
Expand Down Expand Up @@ -201,12 +210,16 @@ public int hashCode() {
*/
public void mergeBlock(LinearIRBlock otherBlock) {
stmts.addAll(otherBlock.getStmts());

updateKindByLastInstr(stmts.get(stmts.size() - 1));

setSuccessors(otherBlock.getSuccessors());

for(var s : otherBlock.getSuccessors()) {
s.predecessors.remove(otherBlock);
s.predecessors.add(this);
}

for (var jmp : otherBlock.getJmpRefMap()) {
if (jmp instanceof JMP jmpInstr) {
jmpInstr.setNext(this);
Expand All @@ -217,4 +230,8 @@ public void mergeBlock(LinearIRBlock otherBlock) {
}
}
}

public void removeSuccessor(LinearIRBlock block) {
successors.remove(block);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ public Void visit(Label label) {
if (indents > 0) { indents--; }

if (label instanceof FuncEntryLabel){
// reset indent
indents = 0;
emit("%s".formatted(label.toSource()));
} else {
emit("%s:".formatted(label.toSource()));
Expand Down
Loading

0 comments on commit 2eb99a9

Please sign in to comment.