Skip to content

Commit

Permalink
Remove redundant jumps and merge blocks
Browse files Browse the repository at this point in the history
- Removed redundant jump instructions when jump target matches next block
- Merged blocks when in-degree is 1 and out-degree of predecessor is 1
- Updated CFG edges and nodes accordingly after optimizations
- Added ControlFlowAnalysis optimizer to perform optimizations
  • Loading branch information
whtoo committed Dec 1, 2023
1 parent c93093d commit ed60830
Show file tree
Hide file tree
Showing 9 changed files with 252 additions and 93 deletions.
4 changes: 0 additions & 4 deletions ep18/src/main/resources/t.vm
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
load 0
iconst 1
isub
br L1
L1:
ret
.def main: args=0 ,locals=1
iconst 10
store 0
br L4
L4:
load 0
iconst 0
Expand Down Expand Up @@ -38,6 +35,5 @@ L8:
br L4
L6:
iconst 0
br L3
L3:
halt
41 changes: 40 additions & 1 deletion ep20/src/main/java/org/teachfx/antlr4/ep20/Compiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
import org.teachfx.antlr4.ep20.parser.CymbolLexer;
import org.teachfx.antlr4.ep20.parser.CymbolParser;
import org.teachfx.antlr4.ep20.pass.ast.CymbolASTBuilder;
import org.teachfx.antlr4.ep20.pass.cfg.ControlFlowAnalysis;
import org.teachfx.antlr4.ep20.pass.codegen.CymbolAssembler;
import org.teachfx.antlr4.ep20.pass.ir.CymbolIRBuilder;
import org.teachfx.antlr4.ep20.pass.symtab.LocalDefine;

import java.io.*;
import java.util.LinkedList;
import java.util.List;

public class Compiler {
Expand Down Expand Up @@ -50,14 +52,23 @@ public static void main(String[] args) throws IOException {
astRoot.accept(irBuilder);

irBuilder.prog.optimizeBasicBlock();
var cfgOptimizer = new ControlFlowAnalysis<IRNode>();
var cnt = 0;
var codeBuffer = new LinkedList<IRNode>();

for(var funBlock : irBuilder.prog.blockList) {
var cfg = irBuilder.getCFG(funBlock);
saveToEp20Res(cfg.toString(),"origin"+cnt);
cfg.addOptimizer(cfgOptimizer);
cfg.applyOptimizers();
saveToEp20Res(cfg.toString(),"optimized"+cnt);
cnt++;
logger.info("CFG:\n" + cfg.toString());
codeBuffer.addAll(cfg.getIRNodes());
}

var assembler = new CymbolAssembler();
assembler.visit(irBuilder.prog.linearInstrs());
assembler.visit(codeBuffer);
saveToEp18Res(assembler.getAsmInfo());
logger.info("\n%s".formatted(assembler.getAsmInfo()));
}
Expand All @@ -84,4 +95,32 @@ protected static void saveToEp18Res(String buffer) {
logger.error("模块路径不存在!");
}
}

protected static void saveToEp20Res(String buffer,String suffix) {
String modulePath = "./src/main/resources"; // 替换 "my-module" 为你的模块名称
File moduleDirectory = new File(modulePath);
logger.info("file path %s".formatted(moduleDirectory.getAbsolutePath()));
if (moduleDirectory.exists()) {
logger.info("模块路径:" + moduleDirectory.getAbsolutePath());
var filePath = modulePath+"/graph_%s.md".formatted(suffix);
File file = new File(filePath);
try (var outputStream = new FileOutputStream(file)) {
if (!file.exists()) {
file.createNewFile();
}
var mdTmeplate = """
```mermaid
%s
```
""".formatted(buffer);
outputStream.write(mdTmeplate.getBytes());

} catch (IOException e) {
throw new RuntimeException(e);
}

} else {
logger.error("模块路径不存在!");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.teachfx.antlr4.ep20.ir.IRNode;
import org.teachfx.antlr4.ep20.ir.JMPInstr;
import org.teachfx.antlr4.ep20.ir.expr.Operand;
import org.teachfx.antlr4.ep20.ir.stmt.CJMP;
import org.teachfx.antlr4.ep20.ir.stmt.FuncEntryLabel;
import org.teachfx.antlr4.ep20.ir.stmt.JMP;
import org.teachfx.antlr4.ep20.ir.stmt.Label;
import org.teachfx.antlr4.ep20.utils.Kind;
import org.teachfx.antlr4.ep20.utils.StreamUtils;

import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.*;
import java.util.stream.Stream;

public class BasicBlock<I extends IRNode> implements Comparable<BasicBlock<I>>, Iterable<Loc<I>> {

Expand All @@ -29,7 +29,7 @@ public class BasicBlock<I extends IRNode> implements Comparable<BasicBlock<I>>,
protected Label label;

public BasicBlock(Kind kind, List<Loc<I>> codes,Label label,int ord) {
this.codes = codes;
this.codes = new ArrayList<>(codes);
this.label = label;
this.id = ord;
this.kind = kind;
Expand Down Expand Up @@ -110,9 +110,22 @@ public I getLastInstr() {

public void mergeNearBlock(BasicBlock<I> nextBlock) {
/// remove last jump instr
if (getLastInstr() instanceof JMPInstr) {
codes.remove(codes.size() - 1);
}

/// merge instr and update kind to use merge nextblock's kind
codes.addAll(nextBlock.dropLabelSeq());
kind = nextBlock.kind;
}

public void removeLastInstr() {
codes.remove(codes.size() - 1);
/// merge instr
codes.addAll(nextBlock.codes);
kind = Kind.CONTINUOUS;
}


public Stream<I> getIRNodes() {
return StreamUtils.flatMap(codes.stream(), Loc::getInstr);
}
}
97 changes: 86 additions & 11 deletions ep20/src/main/java/org/teachfx/antlr4/ep20/pass/cfg/CFG.java
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
package org.teachfx.antlr4.ep20.pass.cfg;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.lang3.tuple.Triple;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import org.teachfx.antlr4.ep20.ir.IRNode;
import org.teachfx.antlr4.ep20.utils.StreamUtils;

import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;

// TODO: visualize cfg
/*
One graph bind to a function.
*/
public class CFG<I extends IRNode> implements Iterable<BasicBlock<I>> {
private final static Logger logger = LogManager.getLogger(CFG.class);
// index: 第几号节点 -> BasicBlock<I> : 第几号节点对应的BasicBlock<I>节点
public final List<BasicBlock<I>> nodes;
// <from,to> : <起始节点,终止节点>
public final List<Pair<Integer, Integer>> edges;
// <from,to,weight> : <起始节点,终止节点,权重>
public final List<Triple<Integer, Integer,Integer>> edges;
// index: 第几号节点 -> <prev,successors> : <前驱节点的集合,后继节点的集合>
private final List<Pair<Set<Integer>, Set<Integer>>> links;
private final List<IFlowOptimizer<I>> optimizers = new ArrayList<>();

public CFG(List<BasicBlock<I>> nodes, List<Pair<Integer, Integer>> edges) {
public CFG(List<BasicBlock<I>> nodes, List<Triple<Integer, Integer,Integer>> edges) {
// Generate init
var maxOrd = nodes.stream().max(BasicBlock::compareTo).map(BasicBlock::getId).get() + 1;
this.nodes = nodes;
Expand All @@ -35,7 +38,7 @@ public CFG(List<BasicBlock<I>> nodes, List<Pair<Integer, Integer>> edges) {

for (var edge : edges) {
var u = edge.getLeft();
var v = edge.getRight();
var v = edge.getMiddle();
links.get(u).getRight().add(v);
links.get(v).getLeft().add(u);
}
Expand All @@ -45,18 +48,40 @@ public BasicBlock<I> getBlock(int id) {
return nodes.get(id);
}

public Set<Integer> getPrev(int id) {
/**
* 获取前驱节点集合,也就是该节点的前驱节点集合。
* @param id 节点的id
* @return 前驱节点集合。
*/
public Set<Integer> getFrontier(int id) {
return links.get(id).getLeft();
}

/**
* 获取后继节点集合,也就是该节点的后继节点集合。
* @param id 节点的id
* @return 后继节点集合。
*/
public Set<Integer> getSucceed(int id) {
return links.get(id).getRight();
}


// 这个方法是用来获取节点的入度的,入度就是该节点的前驱节点的数量。
public int getInDegree(int id) {
return links.get(id).getLeft().size();
}


/*
@desc 这个方法是用来获取节点的入度的,入度就是该节点的前驱节点的数量。
@param key 节点的id
*/
public Stream<Triple<Integer,Integer,Integer>> getInEdges(int key) {
return edges.stream().filter(edge -> edge.getMiddle() == key);
}

// 这个方法是用来获取节点的出度的,出度就是该节点的后继节点的数量。
public int getOutDegree(int id) {
return links.get(id).getRight().size();
}
Expand All @@ -67,11 +92,9 @@ public Iterator<BasicBlock<I>> iterator() {
return nodes.iterator();
}

public List<I> simplifyIRInstrs() {

return null;
}

// 写个注释
// 这个方法是用来生成dot文件的,这个文件可以用来生成图形的。
@Override
public String toString() {
var graphRenderBuffer = new StringBuilder("graph TD\n");
Expand All @@ -86,11 +109,63 @@ public String toString() {
}

for (var edge : edges) {
graphRenderBuffer.append("L").append(edge.getLeft()).append(" --> ").append("L").append(edge.getRight()).append("\n");
graphRenderBuffer.append("L").append(edge.getLeft()).append(" --> ").append("L").append(edge.getMiddle()).append("\n");
}

return graphRenderBuffer.toString();
}

/// 1. receive two params : srcBlockId and destBlockId
/// 2. remove edge from edges and links
public void removeEdge(int srcBlockId, int destBlockId,int weight) {
edges.remove(Triple.of(srcBlockId, destBlockId,weight));
links.get(srcBlockId).getRight().remove(destBlockId);
links.get(destBlockId).getLeft().remove(srcBlockId);
}


/**
* remove edge from edges and links
* @param edge edge to remove, Triple<Integer,Integer,Integer> : <起始节点,终止节点,权重>
*
*/
public void removeEdge(Triple<Integer,Integer,Integer> edge) {
edges.remove(edge);
var srcBlockId = edge.getLeft();
var destBlockId = edge.getMiddle();
var nonRel = getInEdges(destBlockId).noneMatch(q -> q.getLeft().compareTo(srcBlockId) == 0);

if (nonRel) {
links.get(srcBlockId).getRight().remove(destBlockId);
links.get(destBlockId).getLeft().remove(srcBlockId);
}
}

/// 1. receive one param : id
public void removeNodeById(int id) {
/// 2. remove node from nodes by its id
nodes.removeIf(bb -> bb.getId() == id);
}
public void removeNode(BasicBlock<I> node) {
/// 2. remove node from nodes by its id
nodes.removeIf(bb -> bb.equals(node));
}
/// 1. accept a IFlowOptimizer<I>
/// 2. add it to optimizers
public void addOptimizer(IFlowOptimizer<I> optimizer) {
optimizers.add(optimizer);
}

/// 1. apply all optimizers
public void applyOptimizers() {
// 1. apply all optimizers
for (var optimizer : optimizers) {
optimizer.onHandle(this);
}
}


public List<I> getIRNodes() {
return StreamUtils.flatMap(nodes.stream(), BasicBlock::getIRNodes).toList();
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.teachfx.antlr4.ep20.pass.cfg;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.lang3.tuple.Triple;
import org.teachfx.antlr4.ep20.ir.IRNode;
import org.teachfx.antlr4.ep20.ir.JMPInstr;
import org.teachfx.antlr4.ep20.ir.stmt.CJMP;
Expand All @@ -12,7 +13,7 @@ public class CFGBuilder {
private static final Set<String> cachedEdgeLinks = new HashSet<>();
private final CFG<IRNode> cfg;
private final List<BasicBlock<IRNode>> basicBlocks;
private final List<Pair<Integer, Integer>> edges;
private final List<Triple<Integer, Integer,Integer>> edges;

public CFGBuilder(LinearIRBlock startBlock) {
basicBlocks = new ArrayList<>();
Expand All @@ -33,25 +34,25 @@ private void build(LinearIRBlock block,Set<String> cachedEdgeLinks) {

if (lastInstr instanceof JMP jmp) {
var destOrd = jmp.getNext().getOrd();
var key = currentOrd + "-" + destOrd;
var key = currentOrd + "-" + destOrd + "-" + 5;
if (!cachedEdgeLinks.contains(key)) {
cachedEdgeLinks.add(key);
edges.add(Pair.of(currentOrd, destOrd));
edges.add(Triple.of(currentOrd, destOrd,5));
}
} else if (lastInstr instanceof CJMP cjmp) {
var elseOrd = cjmp.getElseBlock().getOrd();
var key = currentOrd + "-" + elseOrd;
var key = currentOrd + "-" + elseOrd + "-" + 5;
if (!cachedEdgeLinks.contains(key)) {
cachedEdgeLinks.add(key);
edges.add(Pair.of(currentOrd, elseOrd));
edges.add(Triple.of(currentOrd, elseOrd,5));
}
}

for (var successor : block.getSuccessors()){
var key = currentOrd + "-" + successor.getOrd();
var key = currentOrd + "-" + successor.getOrd() + "-" + 10;
if (!cachedEdgeLinks.contains(key)) {
cachedEdgeLinks.add(key);
edges.add(Pair.of(currentOrd, successor.getOrd()));
edges.add(Triple.of(currentOrd, successor.getOrd(),10));
}
build(successor,cachedEdgeLinks);
}
Expand Down
Loading

0 comments on commit ed60830

Please sign in to comment.