Skip to content

Commit

Permalink
[mlir] Allow for using OpPassManager in pass options
Browse files Browse the repository at this point in the history
This significantly simplifies the boilerplate necessary for passes
to define nested pass pipelines.

Differential Revision: https://reviews.llvm.org/D122880
  • Loading branch information
River707 committed Apr 2, 2022
1 parent 6edef13 commit 0d8df98
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 37 deletions.
5 changes: 4 additions & 1 deletion mlir/include/mlir/Pass/PassManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ class OpPassManager {
return {begin(), end()};
}

/// Returns true if the pass manager has no passes.
bool empty() const { return begin() == end(); }

/// Nest a new operation pass manager for the given operation kind under this
/// pass manager.
OpPassManager &nest(StringAttr nestedName);
Expand Down Expand Up @@ -110,7 +113,7 @@ class OpPassManager {
/// of pipelines.
/// Note: The quality of the string representation depends entirely on the
/// the correctness of per-pass overrides of Pass::printAsTextualPipeline.
void printAsTextualPipeline(raw_ostream &os);
void printAsTextualPipeline(raw_ostream &os) const;

/// Raw dump of the pass manager to llvm::errs().
void dump();
Expand Down
98 changes: 92 additions & 6 deletions mlir/include/mlir/Pass/PassOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include <memory>

namespace mlir {
class OpPassManager;

namespace detail {
namespace pass_options {
/// Parse a string containing a list of comma-delimited elements, invoking the
Expand Down Expand Up @@ -158,7 +160,7 @@ class PassOptions : protected llvm::cl::SubCommand {
public OptionBase {
public:
template <typename... Args>
Option(PassOptions &parent, StringRef arg, Args &&... args)
Option(PassOptions &parent, StringRef arg, Args &&...args)
: llvm::cl::opt<DataType, /*ExternalStorage=*/false, OptionParser>(
arg, llvm::cl::sub(parent), std::forward<Args>(args)...) {
assert(!this->isPositional() && !this->isSink() &&
Expand Down Expand Up @@ -319,7 +321,8 @@ class PassOptions : protected llvm::cl::SubCommand {
/// struct MyPipelineOptions : PassPipelineOptions<MyPassOptions> {
/// ListOption<int> someListFlag{*this, "flag-name", llvm::cl::desc("...")};
/// };
template <typename T> class PassPipelineOptions : public detail::PassOptions {
template <typename T>
class PassPipelineOptions : public detail::PassOptions {
public:
/// Factory that parses the provided options and returns a unique_ptr to the
/// struct.
Expand All @@ -335,7 +338,6 @@ template <typename T> class PassPipelineOptions : public detail::PassOptions {
/// any options.
struct EmptyPipelineOptions : public PassPipelineOptions<EmptyPipelineOptions> {
};

} // namespace mlir

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -407,8 +409,92 @@ class parser<SmallVector<T, N>>
public:
parser(Option &opt) : detail::VectorParserBase<SmallVector<T, N>, T>(opt) {}
};
} // end namespace cl
} // end namespace llvm

#endif // MLIR_PASS_PASSOPTIONS_H_
//===----------------------------------------------------------------------===//
// OpPassManager: OptionValue

template <>
struct OptionValue<mlir::OpPassManager> final : GenericOptionValue {
using WrapperType = mlir::OpPassManager;

OptionValue();
OptionValue(const mlir::OpPassManager &value);
OptionValue<mlir::OpPassManager> &operator=(const mlir::OpPassManager &rhs);
~OptionValue();

/// Returns if the current option has a value.
bool hasValue() const { return value.get(); }

/// Returns the current value of the option.
mlir::OpPassManager &getValue() const {
assert(hasValue() && "invalid option value");
return *value;
}

/// Set the value of the option.
void setValue(const mlir::OpPassManager &newValue);
void setValue(StringRef pipelineStr);

/// Compare the option with the provided value.
bool compare(const mlir::OpPassManager &rhs) const;
bool compare(const GenericOptionValue &rhs) const override {
const auto &rhsOV =
static_cast<const OptionValue<mlir::OpPassManager> &>(rhs);
if (!rhsOV.hasValue())
return false;
return compare(rhsOV.getValue());
}

private:
void anchor() override;

/// The underlying pass manager. We use a unique_ptr to avoid the need for the
/// full type definition.
std::unique_ptr<mlir::OpPassManager> value;
};

//===----------------------------------------------------------------------===//
// OpPassManager: Parser

extern template class basic_parser<mlir::OpPassManager>;

template <>
class parser<mlir::OpPassManager> : public basic_parser<mlir::OpPassManager> {
public:
/// A utility struct used when parsing a pass manager that prevents the need
/// for a default constructor on OpPassManager.
struct ParsedPassManager {
ParsedPassManager();
ParsedPassManager(ParsedPassManager &&);
~ParsedPassManager();
operator const mlir::OpPassManager &() const {
assert(value && "parsed value was invalid");
return *value;
}

std::unique_ptr<mlir::OpPassManager> value;
};
using parser_data_type = ParsedPassManager;
using OptVal = OptionValue<mlir::OpPassManager>;

parser(Option &opt) : basic_parser(opt) {}

bool parse(Option &, StringRef, StringRef arg, ParsedPassManager &value);

/// Print an instance of the underling option value to the given stream.
static void print(raw_ostream &os, const mlir::OpPassManager &value);

// Overload in subclass to provide a better default value.
StringRef getValueName() const override { return "pass-manager"; }

void printOptionDiff(const Option &opt, mlir::OpPassManager &pm,
const OptVal &defaultValue, size_t globalWidth) const;

// An out-of-line virtual method to provide a 'home' for this class.
void anchor() override;
};

} // namespace cl
} // namespace llvm

#endif // MLIR_PASS_PASSOPTIONS_H_
2 changes: 1 addition & 1 deletion mlir/include/mlir/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def Inliner : Pass<"inline"> {
let options = [
Option<"defaultPipelineStr", "default-pipeline", "std::string",
/*default=*/"", "The default optimizer pipeline used for callables">,
ListOption<"opPipelineStrs", "op-pipelines", "std::string",
ListOption<"opPipelineList", "op-pipelines", "OpPassManager",
"Callable operation specific optimizer pipelines (in the form "
"of `dialect.op(pipeline)`)">,
Option<"maxInliningIterations", "max-iterations", "unsigned",
Expand Down
25 changes: 15 additions & 10 deletions mlir/lib/Pass/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@ void Pass::copyOptionValuesFrom(const Pass *other) {
void Pass::printAsTextualPipeline(raw_ostream &os) {
// Special case for adaptors to use the 'op_name(sub_passes)' format.
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) {
llvm::interleaveComma(adaptor->getPassManagers(), os,
[&](OpPassManager &pm) {
os << pm.getOpName() << "(";
pm.printAsTextualPipeline(os);
os << ")";
});
llvm::interleave(
adaptor->getPassManagers(),
[&](OpPassManager &pm) {
os << pm.getOpName() << "(";
pm.printAsTextualPipeline(os);
os << ")";
},
[&] { os << ","; });
return;
}
// Otherwise, print the pass argument followed by its options. If the pass
Expand Down Expand Up @@ -295,14 +297,17 @@ OperationName OpPassManager::getOpName(MLIRContext &context) const {
/// Prints out the given passes as the textual representation of a pipeline.
static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
raw_ostream &os) {
llvm::interleaveComma(passes, os, [&](const std::unique_ptr<Pass> &pass) {
pass->printAsTextualPipeline(os);
});
llvm::interleave(
passes,
[&](const std::unique_ptr<Pass> &pass) {
pass->printAsTextualPipeline(os);
},
[&] { os << ","; });
}

/// Prints out the passes of the pass manager as the textual representation
/// of pipelines.
void OpPassManager::printAsTextualPipeline(raw_ostream &os) {
void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
::printAsTextualPipeline(impl->passes, os);
}

Expand Down
98 changes: 98 additions & 0 deletions mlir/lib/Pass/PassRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,104 @@ size_t detail::PassOptions::getOptionWidth() const {
return max;
}

//===----------------------------------------------------------------------===//
// MLIR Options
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// OpPassManager: OptionValue

llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
llvm::cl::OptionValue<OpPassManager>::OptionValue(
const mlir::OpPassManager &value) {
setValue(value);
}
llvm::cl::OptionValue<OpPassManager> &
llvm::cl::OptionValue<OpPassManager>::operator=(
const mlir::OpPassManager &rhs) {
setValue(rhs);
return *this;
}

llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;

void llvm::cl::OptionValue<OpPassManager>::setValue(
const OpPassManager &newValue) {
if (hasValue())
*value = newValue;
else
value = std::make_unique<mlir::OpPassManager>(newValue);
}
void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
assert(succeeded(pipeline) && "invalid pass pipeline");
setValue(*pipeline);
}

bool llvm::cl::OptionValue<OpPassManager>::compare(
const mlir::OpPassManager &rhs) const {
std::string lhsStr, rhsStr;
{
raw_string_ostream lhsStream(lhsStr);
value->printAsTextualPipeline(lhsStream);

raw_string_ostream rhsStream(rhsStr);
rhs.printAsTextualPipeline(rhsStream);
}

// Use the textual format for pipeline comparisons.
return lhsStr == rhsStr;
}

void llvm::cl::OptionValue<OpPassManager>::anchor() {}

//===----------------------------------------------------------------------===//
// OpPassManager: Parser

namespace llvm {
namespace cl {
template class basic_parser<OpPassManager>;
} // namespace cl
} // namespace llvm

bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
ParsedPassManager &value) {
FailureOr<OpPassManager> pipeline = parsePassPipeline(arg);
if (failed(pipeline))
return true;
value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
return false;
}

void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
const OpPassManager &value) {
value.printAsTextualPipeline(os);
}

void llvm::cl::parser<OpPassManager>::printOptionDiff(
const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
size_t globalWidth) const {
printOptionName(opt, globalWidth);
outs() << "= ";
pm.printAsTextualPipeline(outs());

if (defaultValue.hasValue()) {
outs().indent(2) << " (default: ";
defaultValue.getValue().printAsTextualPipeline(outs());
outs() << ")";
}
outs() << "\n";
}

void llvm::cl::parser<OpPassManager>::anchor() {}

llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() =
default;
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager(
ParsedPassManager &&) = default;
llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() =
default;

//===----------------------------------------------------------------------===//
// TextualPassPipeline Parser
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 5 additions & 17 deletions mlir/lib/Transforms/Inliner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,14 +585,8 @@ InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
return;

// Update the option for the op specific optimization pipelines.
for (auto &it : opPipelines) {
std::string pipeline;
llvm::raw_string_ostream pipelineOS(pipeline);
pipelineOS << it.getKey() << "(";
it.second.printAsTextualPipeline(pipelineOS);
pipelineOS << ")";
opPipelineStrs.addValue(pipeline);
}
for (auto &it : opPipelines)
opPipelineList.addValue(it.second);
this->opPipelines.emplace_back(std::move(opPipelines));
}

Expand Down Expand Up @@ -751,15 +745,9 @@ LogicalResult InlinerPass::initializeOptions(StringRef options) {

// Initialize the op specific pass pipelines.
llvm::StringMap<OpPassManager> pipelines;
for (StringRef pipeline : opPipelineStrs) {
// Skip empty pipelines.
if (pipeline.empty())
continue;
FailureOr<OpPassManager> pm = parsePassPipeline(pipeline);
if (failed(pm))
return failure();
pipelines.try_emplace(pm->getOpName(), std::move(*pm));
}
for (OpPassManager pipeline : opPipelineList)
if (!pipeline.empty())
pipelines.try_emplace(pipeline.getOpName(), pipeline);
opPipelines.assign({std::move(pipelines)});

return success();
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Transforms/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define TRANSFORMS_PASSDETAIL_H_

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Pass/crash-recovery.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ module @inner_mod1 {
module @foo {}
}

// REPRO: configuration: -pass-pipeline='builtin.module(test-module-pass, test-pass-crash)'
// REPRO: configuration: -pass-pipeline='builtin.module(test-module-pass,test-pass-crash)'

// REPRO: module @inner_mod1
// REPRO: module @foo {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Pass/pipeline-options-parsing.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@

// CHECK_1: test-options-pass{list=1,2,3,4,5 string=nested_pipeline{arg1=10 arg2=" {} " arg3=true} string-list=a,b,c,d}
// CHECK_2: test-options-pass{list=1 string= string-list=a,b}
// CHECK_3: builtin.module(func.func(test-options-pass{list=3 string= }), func.func(test-options-pass{list=1,2,3,4 string= }))
// CHECK_3: builtin.module(func.func(test-options-pass{list=3 string= }),func.func(test-options-pass{list=1,2,3,4 string= }))
1 change: 1 addition & 0 deletions mlir/test/Transforms/inlining.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s
// RUN: mlir-opt %s -inline='default-pipeline=''' -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s --check-prefix INLINE-LOC
// RUN: mlir-opt %s -inline | FileCheck %s --check-prefix INLINE_SIMPLIFY
// RUN: mlir-opt %s -inline='op-pipelines=func.func(canonicalize,cse)' | FileCheck %s --check-prefix INLINE_SIMPLIFY

// Inline a function that takes an argument.
func @func_with_arg(%c : i32) -> i32 {
Expand Down

0 comments on commit 0d8df98

Please sign in to comment.