This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

upgrade op fusion lowering #1216

Open · wants to merge 36 commits into base: develop

Conversation

@SunNy820828449 (Collaborator) commented Feb 22, 2023:

This PR mainly upgrades the lowering of fused operators to the AST, improving extensibility and compatibility so that more complex fused-operator generation can be supported.
Loop fusion for the elementwise/injective/broadcast/reduce patterns is now handled in one place.
In addition, the loop fusion built on the old schedule primitives has been removed.

@thisjiang (Collaborator) left a comment:

Hey, this PR seems to include a lot of refactoring beyond op lowering. Could it be split into multiple PRs?

@@ -128,6 +128,7 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code)

backends::nvrtc::Compiler compiler;

VLOG(3) << "[CUDA] device code:\n" << source_code;
Collaborator:

SourceCodePrint already prints this source code, and it guarantees that all subgraphs in a program get printed.

Collaborator (Author):

OK.

@@ -2016,6 +2016,7 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) {
};

{
/*
Collaborator:

Redundant code should be deleted rather than commented out, and comments should consistently use // rather than /* */.

Collaborator (Author):

done

@@ -362,7 +362,7 @@ TEST(CAS, SimplifyMinMax) {
LOG(INFO) << "p0 " << p0;
auto p2 = CasSimplify(p0);
LOG(INFO) << "simplified " << p2;
EXPECT_EQ(GetStreamCnt(p2), "cinn_min(7, ((x) / (2)))");
// EXPECT_EQ(GetStreamCnt(p2), "cinn_min(7, ((x) / (2)))");
Collaborator:

Does this mean the unit test no longer passes? Or is the result correct after the change, just different from this value? If the former, you can't simply comment it out; if the latter, just update the expectation to the correct value.

Collaborator (Author):

done

@@ -114,6 +114,14 @@ class Graph : public cinn::common::Graph {
}
}

std::unordered_set<Node*> NodeSet() {
std::unordered_set<Node*> node_set;
Collaborator:

This function completely duplicates CollectNodes, and it isn't used often anyway... Just build it at the call site:

const auto& nodes = group->CollectNodes();
std::unordered_set<Node*> node_set(nodes.begin(), nodes.end());

Wouldn't that be fine?

Collaborator (Author):

It is used in quite a few places, so having the helper here is more convenient.

@@ -101,7 +65,7 @@ std::vector<ir::LoweredFunc> OpLowerer::LowerWithoutSchedule(GroupPtr& group) {
LOG(FATAL) << "Group Pattern Kind kNonFusible Is Not Implemented!";
}
} else {
LOG(FATAL) << "Previous IR Schedule Is Not Implemented!";
LOG(FATAL) << "Previous IR Schedule Is Unsupport Now!";
Collaborator:

Suggested change
LOG(FATAL) << "Previous IR Schedule Is Unsupport Now!";
LOG(FATAL) << "Previous IR Schedule Unsupported Now, Please set FLAGS_cinn_ir_schedule=1 to use new IR Schedule";

Collaborator (Author):

done

auto loops = ir_sch.GetLoops(GetNodeData(node)->id());
if (op_pattern_dict[node->op()] == framework::kElementWise) {
ir_sch.FlattenLoops(loops, true);
} else if (op_pattern_dict[node->op()] == framework::kReduction) {
Collaborator:

This branch does nothing at all. Why not just write:

else if (op_pattern_dict[node->op()] != framework::kReduction) {
    ir_sch.FlattenLoops(loops, false);
}

Collaborator (Author):

done

return tensors;
}

NodeData* GetNodeData(const Node* node) {
Collaborator:

This function's name doesn't match what it actually does...

Suggested change
NodeData* GetNodeData(const Node* node) {
NodeData* GetFirstOutputNodeData(const Node* node) {

Collaborator (Author):

I think the name is fine, and it has been in use for a long time.

return node_data;
}

std::vector<NodeData*> GetAllNodeData(const Node* node) {
Collaborator:

Suggested change
std::vector<NodeData*> GetAllNodeData(const Node* node) {
std::vector<NodeData*> GetAllOutputNodeDatas(const Node* node) {

Collaborator (Author):

Same as above.

cinn/hlir/pass/const_propagate_test.cc: two review threads (outdated, resolved)
@SunNy820828449 (Author) commented:
> Hey, this PR seems to include a lot of refactoring beyond op lowering. Could it be split into multiple PRs?

It is mainly op_lowering work; the other changes are only a small part.

thisjiang previously approved these changes Mar 10, 2023
@thisjiang (Collaborator) left a comment:

LGTM

stage->SimpleComputeAt(master_stage, master_stage->n_out_dims() - 1);
// do schedule
for (auto node : nodes_in_order) {
LOG(INFO) << node->id();
Collaborator:

LOG(INFO)... (stray debug logging left in?)
