diff --git a/ep20/src/main/java/org/teachfx/antlr4/ep20/Compiler.java b/ep20/src/main/java/org/teachfx/antlr4/ep20/Compiler.java index 41cc9a5..ab216e8 100644 --- a/ep20/src/main/java/org/teachfx/antlr4/ep20/Compiler.java +++ b/ep20/src/main/java/org/teachfx/antlr4/ep20/Compiler.java @@ -48,8 +48,11 @@ public static void main(String[] args) throws IOException { var irBuilder = new CymbolIRBuilder(); astRoot.accept(irBuilder); - for(var block : irBuilder.prog.blockList) { - var cfg = irBuilder.getCFG(List.of(block)); + + irBuilder.prog.optimizeBasicBlock(); + + for(var funBlock : irBuilder.prog.blockList) { + var cfg = irBuilder.getCFG(funBlock); logger.info("CFG:\n" + cfg.toString()); } diff --git a/ep20/src/main/java/org/teachfx/antlr4/ep20/ir/Prog.java b/ep20/src/main/java/org/teachfx/antlr4/ep20/ir/Prog.java index 130a540..46e4fd2 100644 --- a/ep20/src/main/java/org/teachfx/antlr4/ep20/ir/Prog.java +++ b/ep20/src/main/java/org/teachfx/antlr4/ep20/ir/Prog.java @@ -9,10 +9,7 @@ import org.teachfx.antlr4.ep20.ir.stmt.Label; import org.teachfx.antlr4.ep20.pass.cfg.LinearIRBlock; -import java.util.ArrayList; -import java.util.LinkedList; -import java.util.List; -import java.util.Objects; +import java.util.*; public class Prog extends IRNode { @@ -33,14 +30,14 @@ public S accept(IRVisitor visitor){ public void addBlock(LinearIRBlock linearIRBlock) { blockList.add(linearIRBlock); } + protected TreeSet needRemovedBlocks = new TreeSet<>(); - private void linearInstrsImpl(@NotNull LinearIRBlock linearIRBlock) { - // Add all instr from non-empty block - if (!linearIRBlock.getStmts().isEmpty()) { - instrs.addAll(linearIRBlock.getStmts()); - } else { + private void optimizeEmptyBlock(@NotNull LinearIRBlock linearIRBlock) { + // replace empty block within non-empty first successor + if (linearIRBlock.getStmts().isEmpty()){ // Drop empty block if (linearIRBlock.getSuccessors().isEmpty()) { + needRemovedBlocks.add(linearIRBlock); return; } // Auto-fill next block for jmp/cjmp @@ -48,16 +45,60 @@ private void linearInstrsImpl(@NotNull LinearIRBlock linearIRBlock) { for (var ref : linearIRBlock.getJmpRefMap()){ if (ref instanceof JMP jmp) { jmp.setNext(nextBlock); + } else if (ref instanceof CJMP cjmp) { cjmp.setElseBlock(nextBlock); } } + linearIRBlock.getPredecessors().forEach(prev -> { + prev.removeSuccessor(linearIRBlock); + prev.getSuccessors().add(nextBlock); + }); + + needRemovedBlocks.add(linearIRBlock); } // recursive call for(var successor : linearIRBlock.getSuccessors()){ - linearInstrsImpl(successor); + optimizeEmptyBlock(successor); + } + } + + private void insertLabelForBlock(LinearIRBlock startBlock) { + for (var stmt : startBlock.getStmts()) { + if (stmt instanceof Label) { + break; + } + + startBlock.insertStmt(new Label(startBlock.getScope(), startBlock.getOrd()),0); + break; // only insert one label for each block which is not func-entry block. + } + + for (var successor : startBlock.getSuccessors()) { + insertLabelForBlock(successor); + } + } + protected void buildInstrs(LinearIRBlock block) { + instrs.addAll(block.getStmts()); + + for (var successor : block.getSuccessors()) { + buildInstrs(successor); + } + + } + + public void optimizeBasicBlock() { + for(var func : blockList) { + optimizeEmptyBlock(func); + } + + for( var emptyBlock : needRemovedBlocks) { + emptyBlock.getPredecessors().forEach(p -> p.removeSuccessor(emptyBlock)); + } + + for(var func : blockList) { + insertLabelForBlock(func); } } @@ -67,8 +108,8 @@ public List linearInstrs() { return truncateInstrList; } - for(var block : blockList) { - linearInstrsImpl(block); + for(var func : blockList) { + buildInstrs(func); } IRNode prev; diff --git a/ep20/src/main/java/org/teachfx/antlr4/ep20/ir/stmt/Label.java b/ep20/src/main/java/org/teachfx/antlr4/ep20/ir/stmt/Label.java index 423a513..bcfbc84 100644 --- a/ep20/src/main/java/org/teachfx/antlr4/ep20/ir/stmt/Label.java +++ b/ep20/src/main/java/org/teachfx/antlr4/ep20/ir/stmt/Label.java @@ -31,9 +31,7 @@ public Label(Scope scope,Integer ord) { } public Label(Scope scope) { - this.scope = scope; - this.seq = scope.getLabelSeq(); - this.rawLabel = null; + this(scope, scope.getLabelSeq()); } public void setRawLabel(String rawLabel) { diff --git a/ep20/src/main/java/org/teachfx/antlr4/ep20/ir/stmt/Stmt.java b/ep20/src/main/java/org/teachfx/antlr4/ep20/ir/stmt/Stmt.java index e3f31e5..e60e12a 100644 --- a/ep20/src/main/java/org/teachfx/antlr4/ep20/ir/stmt/Stmt.java +++ b/ep20/src/main/java/org/teachfx/antlr4/ep20/ir/stmt/Stmt.java @@ -15,4 +15,5 @@ public enum StmtType { public abstract S accept(IRVisitor visitor); public abstract StmtType getStmtType(); + } diff --git a/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/BasicBlock.java b/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/BasicBlock.java index 4447db3..df1f689 100644 --- a/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/BasicBlock.java +++ b/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/BasicBlock.java @@ -4,6 +4,9 @@ import org.jetbrains.annotations.NotNull; import org.teachfx.antlr4.ep20.ir.IRNode; import org.teachfx.antlr4.ep20.ir.expr.Operand; +import org.teachfx.antlr4.ep20.ir.stmt.CJMP; +import org.teachfx.antlr4.ep20.ir.stmt.FuncEntryLabel; +import org.teachfx.antlr4.ep20.ir.stmt.JMP; import org.teachfx.antlr4.ep20.ir.stmt.Label; import org.teachfx.antlr4.ep20.utils.Kind; @@ -25,17 +28,17 @@ public class BasicBlock implements Comparable>, public Set liveOut; protected Label label; - public BasicBlock(Kind kind, List> codes, Label label) { + public BasicBlock(Kind kind, List> codes,Label label,int ord) { this.codes = codes; this.label = label; - this.id = label.getSeq(); + this.id = ord; this.kind = kind; } @NotNull @Contract("_ -> new") - public static BasicBlock buildFromLinearBlock(@NotNull LinearIRBlock block) { - return new BasicBlock(block.getKind(), block.getStmts().stream().map(Loc::new).toList(), block.getLabel()); + public static BasicBlock buildFromLinearBlock(@NotNull LinearIRBlock block,List> cachedNodes) { + return new BasicBlock(block.getKind(), block.getStmts().stream().map(Loc::new).toList(),block.getLabel(),block.getOrd()); } @Override @@ -93,6 +96,11 @@ public List> allSeq() { public List> dropLabelSeq() { if (codes.size() <= 1) return codes; + + if (codes.get(0).instr instanceof FuncEntryLabel) { + return codes.subList(0, codes.size()); + } + return codes.subList(1, codes.size()); } diff --git a/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/CFG.java b/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/CFG.java index fa1269c..5ad8437 100644 --- a/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/CFG.java +++ b/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/CFG.java @@ -22,15 +22,14 @@ public class CFG implements Iterable> { // index: 第几号节点 -> : <前驱节点的集合,后继节点的集合> private final List, Set>> links; - public CFG(Map> nodes, List> edges) { + public CFG(List> nodes, List> edges) { // Generate init - var lastOrd = Integer.max(nodes.keySet().stream().max(Integer::compareTo).get(), nodes.size()) + 1; - - this.nodes = new LinkedList<>(nodes.values()); + var maxOrd = nodes.stream().max(BasicBlock::compareTo).map(BasicBlock::getId).get() + 1; + this.nodes = nodes; this.edges = edges; links = new ArrayList<>(); - for (var i = 0; i < lastOrd; i++) { + for (var i = 0; i < maxOrd; i++) { links.add(Pair.of(new TreeSet<>(), new TreeSet<>())); } @@ -78,7 +77,7 @@ public String toString() { var graphRenderBuffer = new StringBuilder("graph TD\n"); AtomicInteger i = new AtomicInteger(); - for (var node : nodes.stream().sorted((b1,b2)-> b2.id - b1.id).toList()) { + for (var node : nodes.stream().sorted((b1,b2) -> b2.id - b1.id).toList()) { graphRenderBuffer.append("subgraph ").append(node.getOrdLabel()).append("\n"); node.dropLabelSeq().stream().map(x -> x.instr.toString()).map(x -> "Q" + (i.getAndIncrement()) + "[\"" + x + ";\"]\n").forEach(graphRenderBuffer::append); @@ -87,7 +86,6 @@ public String toString() { } for (var edge : edges) { - graphRenderBuffer.append("L").append(edge.getLeft()).append(" --> ").append("L").append(edge.getRight()).append("\n"); } diff --git a/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/CFGBuilder.java b/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/CFGBuilder.java index 00d8848..f949619 100644 --- a/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/CFGBuilder.java +++ b/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/CFGBuilder.java @@ -2,29 +2,62 @@ import org.apache.commons.lang3.tuple.Pair; import org.teachfx.antlr4.ep20.ir.IRNode; +import org.teachfx.antlr4.ep20.ir.JMPInstr; +import org.teachfx.antlr4.ep20.ir.stmt.CJMP; +import org.teachfx.antlr4.ep20.ir.stmt.JMP; import java.util.*; public class CFGBuilder { private static final Set cachedEdgeLinks = new HashSet<>(); private final CFG cfg; - private final Map> basicBlocks; + private final List> basicBlocks; private final List> edges; - public CFGBuilder(List blockList) { - basicBlocks = new HashMap<>(); + public CFGBuilder(LinearIRBlock startBlock) { + basicBlocks = new ArrayList<>(); edges = new ArrayList<>(); - for (var funcLabelBlock : blockList) { - build(funcLabelBlock); - } + + var cachedEdgeLink = new HashSet(); + + build(startBlock,cachedEdgeLink); + cfg = new CFG<>(basicBlocks, edges); } - private void build(LinearIRBlock block) { - basicBlocks.put(block.getLabel().getSeq(), BasicBlock.buildFromLinearBlock(block)); + private void build(LinearIRBlock block,Set cachedEdgeLinks) { + var currentBlock = BasicBlock.buildFromLinearBlock(block,basicBlocks); + basicBlocks.add(currentBlock); + var lastInstr = block.getStmts().get(block.getStmts().size() - 1); + var currentOrd = block.getOrd(); + if (lastInstr instanceof JMP jmp) { + var destOrd = jmp.getNext().getOrd(); + var key = currentOrd + "-" + destOrd; + if (!cachedEdgeLinks.contains(key)) { + cachedEdgeLinks.add(key); + edges.add(Pair.of(currentOrd, destOrd)); + } + } else if (lastInstr instanceof CJMP cjmp) { + var elseOrd = cjmp.getElseBlock().getOrd(); + var key = currentOrd + "-" + elseOrd; + if (!cachedEdgeLinks.contains(key)) { + cachedEdgeLinks.add(key); + edges.add(Pair.of(currentOrd, elseOrd)); + } + } + + for (var successor : block.getSuccessors()){ + var key = currentOrd + "-" + successor.getOrd(); + if (!cachedEdgeLinks.contains(key)) { + cachedEdgeLinks.add(key); + edges.add(Pair.of(currentOrd, successor.getOrd())); + } + build(successor,cachedEdgeLinks); + } } + public CFG getCFG() { return cfg; } diff --git a/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/LinearIRBlock.java b/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/LinearIRBlock.java index e019544..db0b52d 100644 --- a/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/LinearIRBlock.java +++ b/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/LinearIRBlock.java @@ -26,14 +26,23 @@ public class LinearIRBlock implements Comparable { private List jmpRefMap = new ArrayList<>(); // Constructor - public LinearIRBlock() { + public LinearIRBlock(Scope scope) { stmts = new ArrayList<>(); successors = new ArrayList<>(); predecessors = new ArrayList<>(); ord = LABEL_SEQ++; + this.scope = scope; logger.info(ord); } + public LinearIRBlock() { + stmts = new ArrayList<>(); + successors = new ArrayList<>(); + predecessors = new ArrayList<>(); + ord = LABEL_SEQ++; + this.scope = null; + logger.info(ord); + } public static boolean isBasicBlock(Stmt stmt) { return !(stmt instanceof CJMP) && !(stmt instanceof JMP); } @@ -51,6 +60,10 @@ public void addStmt(IRNode stmt) { updateKindByLastInstr(stmt); } + public void insertStmt(IRNode stmt,int idx) { + stmts.add(idx, stmt); + } + private void updateKindByLastInstr(IRNode stmt) { if (stmt instanceof CJMP) { kind = Kind.END_BY_CJMP; @@ -152,14 +165,10 @@ public Label getLabel() { var firstInstr = stmts.get(0); if (firstInstr instanceof Label label) { - return label; - } else { - stmts.add(0, new Label(scope, getOrd())); - firstInstr = stmts.get(0); } - return (Label) firstInstr; + return null; } public Optional> getJumpEntries() { @@ -201,12 +210,16 @@ public int hashCode() { */ public void mergeBlock(LinearIRBlock otherBlock) { stmts.addAll(otherBlock.getStmts()); + updateKindByLastInstr(stmts.get(stmts.size() - 1)); + setSuccessors(otherBlock.getSuccessors()); + for(var s : otherBlock.getSuccessors()) { s.predecessors.remove(otherBlock); s.predecessors.add(this); } + for (var jmp : otherBlock.getJmpRefMap()) { if (jmp instanceof JMP jmpInstr) { jmpInstr.setNext(this); @@ -217,4 +230,8 @@ public void mergeBlock(LinearIRBlock otherBlock) { } } } + + public void removeSuccessor(LinearIRBlock block) { + successors.remove(block); + } } \ No newline at end of file diff --git a/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/codegen/CymbolAssembler.java b/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/codegen/CymbolAssembler.java index 0b655ff..d95c7e7 100644 --- a/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/codegen/CymbolAssembler.java +++ b/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/codegen/CymbolAssembler.java @@ -75,6 +75,8 @@ public Void visit(Label label) { if (indents > 0) { indents--; } if (label instanceof FuncEntryLabel){ + // reset indent + indents = 0; emit("%s".formatted(label.toSource())); } else { emit("%s:".formatted(label.toSource())); diff --git a/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/ir/CymbolIRBuilder.java b/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/ir/CymbolIRBuilder.java index 7d13b1d..cf6f4fb 100644 --- a/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/ir/CymbolIRBuilder.java +++ b/ep20/src/main/java/org/teachfx/antlr4/ep20/pass/ir/CymbolIRBuilder.java @@ -230,9 +230,9 @@ public Void visit(ReturnStmtNode returnStmtNode) { @Override public Void visit(WhileStmtNode whileStmtNode) { curNode = whileStmtNode; - var condBlock = new LinearIRBlock(); - var doBlock = new LinearIRBlock(); - var endBlock = new LinearIRBlock(); + var condBlock = new LinearIRBlock(currentBlock.getScope()); + var doBlock = new LinearIRBlock(currentBlock.getScope()); + var endBlock = new LinearIRBlock(currentBlock.getScope()); jump(condBlock); @@ -263,15 +263,15 @@ public Void visit(IfStmtNode ifStmtNode) { ifStmtNode.getCondExpr().accept(this); var cond = peekEvalOperand(); - var thenBlock = new LinearIRBlock(); - var endBlock = new LinearIRBlock(); + var thenBlock = new LinearIRBlock(currentBlock.getScope()); + var endBlock = new LinearIRBlock(currentBlock.getScope()); if (ifStmtNode.getElseBlock().isEmpty()) { jumpIf(cond,thenBlock,endBlock); setCurrentBlock(thenBlock); ifStmtNode.getThenBlock().accept(this); }else { - var elseBlock = new LinearIRBlock(); + var elseBlock = new LinearIRBlock(currentBlock.getScope()); jumpIf(cond,thenBlock,elseBlock); setCurrentBlock(thenBlock); ifStmtNode.getThenBlock().accept(this); @@ -426,20 +426,8 @@ protected VarSlot peekEvalOperand() { return evalExprStack.peek(); } - public CFG getCFG(List blocks) { - for (var func : blocks){ - insertBlockLabel(func); - } - - var cfgBuilder = new CFGBuilder(blocks); + public CFG getCFG(LinearIRBlock startBlocks) { + var cfgBuilder = new CFGBuilder(startBlocks); return cfgBuilder.getCFG(); } - - public void insertBlockLabel(LinearIRBlock startBlock) { - startBlock.getLabel(); - - for (var successor : startBlock.getSuccessors()) { - insertBlockLabel(successor); - } - } }