Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/docs/PassManagement.md
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,12 @@ each pass, the generator produces a `registerPassName` where
generates a `registerGroupPasses`, where `Group` is the tag provided via the
`-name` input parameter, that registers all of the passes present.

These declarations can be enabled for the whole group of passes by
defining the `GEN_PASS_REGISTRATION` macro, or on a per-pass basis by
defining `GEN_PASS_REGISTRATION_PASSNAME` where `PASSNAME` is the
uppercase version of the name of the pass (similar to pass def and
decls).

```c++
// Tablegen options: -gen-pass-decls -name="Example"

Expand Down
25 changes: 20 additions & 5 deletions mlir/tools/mlir-tblgen/PassGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,23 @@ const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//
#ifdef {1}

inline void register{0}() {{
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
return {1};
return {2};
});
}

// Old registration code, kept for temporary backwards compatibility.
inline void register{0}Pass() {{
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
return {1};
return {2};
});
}

#undef {1}
#endif // {1}
)";

/// The code snippet used to generate a function to register all passes in a
Expand Down Expand Up @@ -116,6 +120,10 @@ static std::string getPassDeclVarName(const Pass &pass) {
return "GEN_PASS_DECL_" + pass.getDef()->getName().upper();
}

static std::string getPassRegistrationVarName(const Pass &pass) {
return "GEN_PASS_REGISTRATION_" + pass.getDef()->getName().upper();
}

/// Emit the code to be included in the public header of the pass.
static void emitPassDecls(const Pass &pass, raw_ostream &os) {
StringRef passName = pass.getDef()->getName();
Expand Down Expand Up @@ -143,18 +151,25 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) {
/// PassRegistry.
static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
os << "#ifdef GEN_PASS_REGISTRATION\n";
os << "// Generate registrations for all passes.\n";
for (const Pass &pass : passes)
os << "#define " << getPassRegistrationVarName(pass) << "\n";
os << "#endif // GEN_PASS_REGISTRATION\n";

for (const Pass &pass : passes) {
std::string passName = pass.getDef()->getName().str();
std::string passEnableVarName = getPassRegistrationVarName(pass);

std::string constructorCall;
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
constructorCall = constructor.str();
else
constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();

os << formatv(passRegistrationCode, pass.getDef()->getName(),
constructorCall = formatv("create{0}()", passName).str();
os << formatv(passRegistrationCode, passName, passEnableVarName,
constructorCall);
}

os << "#ifdef GEN_PASS_REGISTRATION\n";
os << formatv(passGroupRegistrationCode, groupName);

for (const Pass &pass : passes)
Expand Down