Skip to content

Commit

Permalink
[mlir:PassOption] Rework ListOption parsing and add support for std::…
Browse files Browse the repository at this point in the history
…vector/SmallVector options

ListOption currently uses llvm::cl::list under the hood, but the usages
of ListOption are generally a tad different from llvm::cl::list. This
commit codifies this by making ListOption implicitly comma separated,
and removes the explicit flag set for all of the current list options.
The new parsing for comma separation of ListOption also adds in support
for skipping over delimited sub-ranges (i.e. {}, [], (), "", ''). This
more easily supports nested options that use those as part of the
format, and this constraint (balanced delimiters) is already codified
in the syntax of pass pipelines.

See https://discourse.llvm.org/t/list-of-lists-pass-option/5950 for
related discussion

Differential Revision: https://reviews.llvm.org/D122879
  • Loading branch information
River707 committed Apr 2, 2022
1 parent e06ca31 commit 6edef13
Show file tree
Hide file tree
Showing 22 changed files with 237 additions and 95 deletions.
19 changes: 9 additions & 10 deletions mlir/docs/PassManagement.md
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,12 @@ components are integrated with the dynamic pipeline being executed.
MLIR provides a builtin mechanism for passes to specify options that configure
its behavior. These options are parsed at pass construction time independently
for each instance of the pass. Options are defined using the `Option<>` and
`ListOption<>` classes, and follow the
`ListOption<>` classes, and generally follow the
[LLVM command line](https://llvm.org/docs/CommandLine.html) flag definition
rules. See below for a few examples:
rules. One major distinction from the LLVM command line functionality is that
all `ListOption`s are comma-separated, and delimited sub-ranges within individual
elements of the list may contain commas that are not treated as separators for the
top-level list.

```c++
struct MyPass ... {
Expand All @@ -445,8 +448,7 @@ struct MyPass ... {
/// Any parameters after the description are forwarded to llvm::cl::list and
/// llvm::cl::opt respectively.
Option<int> exampleOption{*this, "flag-name", llvm::cl::desc("...")};
ListOption<int> exampleListOption{*this, "list-flag-name",
llvm::cl::desc("...")};
ListOption<int> exampleListOption{*this, "list-flag-name", llvm::cl::desc("...")};
};
```
Expand Down Expand Up @@ -705,8 +707,7 @@ struct MyPass : PassWrapper<MyPass, OperationPass<ModuleOp>> {
llvm::cl::desc("An example option"), llvm::cl::init(true)};
ListOption<int64_t> listOption{
*this, "example-list",
llvm::cl::desc("An example list option"), llvm::cl::ZeroOrMore,
llvm::cl::MiscFlags::CommaSeparated};
llvm::cl::desc("An example list option"), llvm::cl::ZeroOrMore};

// Specify any statistics.
Statistic statistic{this, "example-statistic", "An example statistic"};
Expand Down Expand Up @@ -742,8 +743,7 @@ def MyPass : Pass<"my-pass", "ModuleOp"> {
Option<"option", "example-option", "bool", /*default=*/"true",
"An example option">,
ListOption<"listOption", "example-list", "int64_t",
"An example list option",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
"An example list option", "llvm::cl::ZeroOrMore">
];
// Specify any statistics.
Expand Down Expand Up @@ -879,8 +879,7 @@ The `ListOption` class takes the following fields:
def MyPass : Pass<"my-pass"> {
let options = [
ListOption<"listOption", "example-list", "int64_t",
"An example list option",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
"An example list option", "llvm::cl::ZeroOrMore">
];
}
```
Expand Down
6 changes: 2 additions & 4 deletions mlir/docs/PatternRewriter.md
Original file line number Diff line number Diff line change
Expand Up @@ -439,12 +439,10 @@ below:
```tablegen
ListOption<"disabledPatterns", "disable-patterns", "std::string",
"Labels of patterns that should be filtered out during application",
"llvm::cl::MiscFlags::CommaSeparated">,
"Labels of patterns that should be filtered out during application">,
ListOption<"enabledPatterns", "enable-patterns", "std::string",
"Labels of patterns that should be used during application, all "
"other patterns are filtered out",
"llvm::cl::MiscFlags::CommaSeparated">,
"other patterns are filtered out">,
```
These options may be used to provide filtering behavior when constructing any
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Affine/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def AffineVectorize : Pass<"affine-super-vectorize", "FuncOp"> {
let options = [
ListOption<"vectorSizes", "virtual-vector-size", "int64_t",
"Specify an n-D virtual vector size for vectorization",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
"llvm::cl::ZeroOrMore">,
// Optionally, the fixed mapping from loop to fastest varying MemRef
// dimension for all the MemRefs within a loop pattern:
// the index represents the loop depth, the value represents the k^th
Expand All @@ -359,7 +359,7 @@ def AffineVectorize : Pass<"affine-super-vectorize", "FuncOp"> {
"Specify a 1-D, 2-D or 3-D pattern of fastest varying memory "
"dimensions to match. See defaultPatterns in Vectorize.cpp for "
"a description and examples. This is used for testing purposes",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
"llvm::cl::ZeroOrMore">,
Option<"vectorizeReductions", "vectorize-reductions", "bool",
/*default=*/"false",
"Vectorize known reductions expressed via iter_args. "
Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,7 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
"Specify if buffers should be deallocated. For compatibility with "
"core bufferization passes.">,
ListOption<"dialectFilter", "dialect-filter", "std::string",
"Restrict bufferization to ops from these dialects.",
"llvm::cl::MiscFlags::CommaSeparated">,
"Restrict bufferization to ops from these dialects.">,
Option<"fullyDynamicLayoutMaps", "fully-dynamic-layout-maps", "bool",
/*default=*/"true",
"Generate MemRef types with dynamic offset+strides by default.">,
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Linalg/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def LinalgTiling : Pass<"linalg-tile", "FuncOp"> {
];
let options = [
ListOption<"tileSizes", "tile-sizes", "int64_t", "Tile sizes",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
"llvm::cl::ZeroOrMore">,
Option<"loopType", "loop-type", "std::string", /*default=*/"\"for\"",
"Specify the type of loops to generate: for, parallel">
];
Expand Down
12 changes: 4 additions & 8 deletions mlir/include/mlir/Dialect/SCF/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,11 @@ def SCFParallelLoopCollapsing : Pass<"scf-parallel-loop-collapsing"> {
let constructor = "mlir::createParallelLoopCollapsingPass()";
let options = [
ListOption<"clCollapsedIndices0", "collapsed-indices-0", "unsigned",
"Which loop indices to combine 0th loop index",
"llvm::cl::MiscFlags::CommaSeparated">,
"Which loop indices to combine 0th loop index">,
ListOption<"clCollapsedIndices1", "collapsed-indices-1", "unsigned",
"Which loop indices to combine into the position 1 loop index",
"llvm::cl::MiscFlags::CommaSeparated">,
"Which loop indices to combine into the position 1 loop index">,
ListOption<"clCollapsedIndices2", "collapsed-indices-2", "unsigned",
"Which loop indices to combine into the position 2 loop index",
"llvm::cl::MiscFlags::CommaSeparated">,
"Which loop indices to combine into the position 2 loop index">,
];
}

Expand All @@ -77,8 +74,7 @@ def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling", "FuncOp"> {
let constructor = "mlir::createParallelLoopTilingPass()";
let options = [
ListOption<"tileSizes", "parallel-loop-tile-sizes", "int64_t",
"Factors to tile parallel loops by",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
"Factors to tile parallel loops by", "llvm::cl::ZeroOrMore">,
Option<"noMinMaxBounds", "no-min-max-bounds", "bool",
/*default=*/"false",
"Perform tiling with fixed upper bound with inbound check "
Expand Down
160 changes: 146 additions & 14 deletions mlir/include/mlir/Pass/PassOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,63 @@

#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include <memory>

namespace mlir {
namespace detail {
namespace pass_options {
/// Parse a string containing a list of comma-delimited elements, invoking the
/// given parser for each sub-element and passing them to the provided
/// element-append functor.
LogicalResult
parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName,
StringRef optionStr,
function_ref<LogicalResult(StringRef)> elementParseFn);
template <typename ElementParser, typename ElementAppendFn>
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName,
StringRef optionStr,
ElementParser &elementParser,
ElementAppendFn &&appendFn) {
return parseCommaSeparatedList(
opt, argName, optionStr, [&](StringRef valueStr) {
typename ElementParser::parser_data_type value = {};
if (elementParser.parse(opt, argName, valueStr, value))
return failure();
appendFn(value);
return success();
});
}

/// Trait used to detect if a type has a operator<< method.
template <typename T>
using has_stream_operator_trait =
decltype(std::declval<raw_ostream &>() << std::declval<T>());
template <typename T>
using has_stream_operator = llvm::is_detected<has_stream_operator_trait, T>;

/// Utility methods for printing option values.
template <typename ParserT>
static void printOptionValue(raw_ostream &os, const bool &value) {
os << (value ? StringRef("true") : StringRef("false"));
}
template <typename ParserT, typename DataT>
static std::enable_if_t<has_stream_operator<DataT>::value>
printOptionValue(raw_ostream &os, const DataT &value) {
os << value;
}
template <typename ParserT, typename DataT>
static std::enable_if_t<!has_stream_operator<DataT>::value>
printOptionValue(raw_ostream &os, const DataT &value) {
// If the value can't be streamed, fallback to checking for a print in the
// parser.
ParserT::print(os, value);
}
} // namespace pass_options

/// Base container class and manager for all pass options.
class PassOptions : protected llvm::cl::SubCommand {
private:
Expand Down Expand Up @@ -85,11 +135,7 @@ class PassOptions : protected llvm::cl::SubCommand {
}
template <typename DataT, typename ParserT>
static void printValue(raw_ostream &os, ParserT &parser, const DataT &value) {
os << value;
}
template <typename ParserT>
static void printValue(raw_ostream &os, ParserT &parser, const bool &value) {
os << (value ? StringRef("true") : StringRef("false"));
detail::pass_options::printOptionValue<ParserT>(os, value);
}

public:
Expand Down Expand Up @@ -149,22 +195,27 @@ class PassOptions : protected llvm::cl::SubCommand {
};

/// This class represents a specific pass option that contains a list of
/// values of the provided data type.
/// values of the provided data type. The elements within the textual form of
/// this option are parsed assuming they are comma-separated. Delimited
/// sub-ranges within individual elements of the list may contain commas that
/// are not treated as separators for the top-level list.
template <typename DataType, typename OptionParser = OptionParser<DataType>>
class ListOption
: public llvm::cl::list<DataType, /*StorageClass=*/bool, OptionParser>,
public OptionBase {
public:
template <typename... Args>
ListOption(PassOptions &parent, StringRef arg, Args &&... args)
ListOption(PassOptions &parent, StringRef arg, Args &&...args)
: llvm::cl::list<DataType, /*StorageClass=*/bool, OptionParser>(
arg, llvm::cl::sub(parent), std::forward<Args>(args)...) {
arg, llvm::cl::sub(parent), std::forward<Args>(args)...),
elementParser(*this) {
assert(!this->isPositional() && !this->isSink() &&
"sink and positional options are not supported");
assert(!(this->getMiscFlags() & llvm::cl::MiscFlags::CommaSeparated) &&
"ListOption is implicitly comma separated, specifying "
"CommaSeparated is extraneous");
parent.options.push_back(this);

// Set a callback to track if this option has a value.
this->setCallback([this](const auto &) { this->optHasValue = true; });
elementParser.initialize();
}
~ListOption() override = default;
ListOption<DataType, OptionParser> &
Expand All @@ -174,6 +225,14 @@ class PassOptions : protected llvm::cl::SubCommand {
return *this;
}

bool handleOccurrence(unsigned pos, StringRef argName,
StringRef arg) override {
this->optHasValue = true;
return failed(detail::pass_options::parseCommaSeparatedList(
*this, argName, arg, elementParser,
[&](const DataType &value) { this->addValue(value); }));
}

/// Allow assigning from an ArrayRef.
ListOption<DataType, OptionParser> &operator=(ArrayRef<DataType> values) {
((std::vector<DataType> &)*this).assign(values.begin(), values.end());
Expand Down Expand Up @@ -211,6 +270,9 @@ class PassOptions : protected llvm::cl::SubCommand {
void copyValueFrom(const OptionBase &other) final {
*this = static_cast<const ListOption<DataType, OptionParser> &>(other);
}

/// The parser to use for parsing the list elements.
OptionParser elementParser;
};

PassOptions() = default;
Expand Down Expand Up @@ -255,9 +317,7 @@ class PassOptions : protected llvm::cl::SubCommand {
/// Usage:
///
/// struct MyPipelineOptions : PassPipelineOptions<MyPassOptions> {
/// ListOption<int> someListFlag{
/// *this, "flag-name", llvm::cl::MiscFlags::CommaSeparated,
/// llvm::cl::desc("...")};
/// ListOption<int> someListFlag{*this, "flag-name", llvm::cl::desc("...")};
/// };
template <typename T> class PassPipelineOptions : public detail::PassOptions {
public:
Expand All @@ -278,5 +338,77 @@ struct EmptyPipelineOptions : public PassPipelineOptions<EmptyPipelineOptions> {

} // namespace mlir

//===----------------------------------------------------------------------===//
// MLIR Options
//===----------------------------------------------------------------------===//

namespace llvm {
namespace cl {
//===----------------------------------------------------------------------===//
// std::vector+SmallVector

namespace detail {
template <typename VectorT, typename ElementT>
class VectorParserBase : public basic_parser_impl {
public:
VectorParserBase(Option &opt) : basic_parser_impl(opt), elementParser(opt) {}

using parser_data_type = VectorT;

bool parse(Option &opt, StringRef argName, StringRef arg,
parser_data_type &vector) {
if (!arg.consume_front("[") || !arg.consume_back("]")) {
return opt.error("expected vector option to be wrapped with '[]'",
argName);
}

return failed(mlir::detail::pass_options::parseCommaSeparatedList(
opt, argName, arg, elementParser,
[&](const ElementT &value) { vector.push_back(value); }));
}

static void print(raw_ostream &os, const VectorT &vector) {
llvm::interleave(
vector, os,
[&](const ElementT &value) {
mlir::detail::pass_options::printOptionValue<
llvm::cl::parser<ElementT>>(os, value);
},
",");
}

void printOptionInfo(const Option &opt, size_t globalWidth) const {
// Add the `vector<>` qualifier to the option info.
outs() << " --" << opt.ArgStr;
outs() << "=<vector<" << elementParser.getValueName() << ">>";
Option::printHelpStr(opt.HelpStr, globalWidth, getOptionWidth(opt));
}

size_t getOptionWidth(const Option &opt) const {
// Add the `vector<>` qualifier to the option width.
StringRef vectorExt("vector<>");
return elementParser.getOptionWidth(opt) + vectorExt.size();
}

private:
llvm::cl::parser<ElementT> elementParser;
};
} // namespace detail

template <typename T>
class parser<std::vector<T>>
: public detail::VectorParserBase<std::vector<T>, T> {
public:
parser(Option &opt) : detail::VectorParserBase<std::vector<T>, T>(opt) {}
};
template <typename T, unsigned N>
class parser<SmallVector<T, N>>
: public detail::VectorParserBase<SmallVector<T, N>, T> {
public:
parser(Option &opt) : detail::VectorParserBase<SmallVector<T, N>, T>(opt) {}
};
} // end namespace cl
} // end namespace llvm

#endif // MLIR_PASS_PASSOPTIONS_H_

3 changes: 1 addition & 2 deletions mlir/include/mlir/Reducer/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ def CommonReductionPassOptions {
Option<"testerName", "test", "std::string", /* default */"",
"The location of the tester which tests the file interestingness">,
ListOption<"testerArgs", "test-arg", "std::string",
"arguments of the tester",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
"arguments of the tester", "llvm::cl::ZeroOrMore">,
];
}

Expand Down
6 changes: 2 additions & 4 deletions mlir/include/mlir/Rewrite/PassUtil.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,10 @@ def RewritePassUtils {
// created.
ListOption<"disabledPatterns", "disable-patterns", "std::string",
"Labels of patterns that should be filtered out during"
" application",
"llvm::cl::MiscFlags::CommaSeparated">,
" application">,
ListOption<"enabledPatterns", "enable-patterns", "std::string",
"Labels of patterns that should be used during"
" application, all other patterns are filtered out",
"llvm::cl::MiscFlags::CommaSeparated">,
" application, all other patterns are filtered out">,
];
}

Expand Down
Loading

0 comments on commit 6edef13

Please sign in to comment.