Skip to content

Commit e093e10

Browse files
committed
[mlir][ods] Enable granular pass registration.
Same as with pass def & decl. This doesn't change anything with registry and the big flag kept (e.g., GEN_PASS_REGISTRATION behaves like GEN_PASS_DECL and so too for sub ones).
1 parent 750a361 commit e093e10

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

mlir/docs/PassManagement.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,12 @@ each pass, the generator produces a `registerPassName` where
835835
generates a `registerGroupPasses`, where `Group` is the tag provided via the
836836
`-name` input parameter, that registers all of the passes present.
837837

838+
These declarations can be enabled for the whole group of passes by
839+
defining the `GEN_PASS_REGISTRATION` macro, or on a per-pass basis by
840+
defining `GEN_PASS_REGISTRATION_PASSNAME` where `PASSNAME` is the
841+
uppercase version of the name of the pass (similar to pass def and
842+
decls).
843+
838844
```c++
839845
// Tablegen options: -gen-pass-decls -name="Example"
840846

mlir/tools/mlir-tblgen/PassGen.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,23 @@ const char *const passRegistrationCode = R"(
5858
// {0} Registration
5959
//===----------------------------------------------------------------------===//
6060
61+
#if defined(GEN_PASS_REGISTRATION) || defined({1})
62+
6163
inline void register{0}() {{
6264
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
63-
return {1};
65+
return {2};
6466
});
6567
}
6668
6769
// Old registration code, kept for temporary backwards compatibility.
6870
inline void register{0}Pass() {{
6971
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
70-
return {1};
72+
return {2};
7173
});
7274
}
75+
76+
#undef {1}
77+
#endif // GEN_PASS_REGISTRATION || {1}
7378
)";
7479

7580
/// The code snippet used to generate a function to register all passes in a
@@ -116,6 +121,10 @@ static std::string getPassDeclVarName(const Pass &pass) {
116121
return "GEN_PASS_DECL_" + pass.getDef()->getName().upper();
117122
}
118123

124+
static std::string getPassRegistrationVarName(const Pass &pass) {
125+
return "GEN_PASS_REGISTRATION_" + pass.getDef()->getName().upper();
126+
}
127+
119128
/// Emit the code to be included in the public header of the pass.
120129
static void emitPassDecls(const Pass &pass, raw_ostream &os) {
121130
StringRef passName = pass.getDef()->getName();
@@ -142,19 +151,20 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) {
142151
/// Emit the code for registering each of the given passes with the global
143152
/// PassRegistry.
144153
static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
145-
os << "#ifdef GEN_PASS_REGISTRATION\n";
146-
147154
for (const Pass &pass : passes) {
155+
std::string passName = pass.getDef()->getName().str();
156+
std::string passEnableVarName = getPassRegistrationVarName(pass);
157+
148158
std::string constructorCall;
149159
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
150160
constructorCall = constructor.str();
151161
else
152-
constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
153-
154-
os << formatv(passRegistrationCode, pass.getDef()->getName(),
162+
constructorCall = formatv("create{0}()", passName).str();
163+
os << formatv(passRegistrationCode, passName, passEnableVarName,
155164
constructorCall);
156165
}
157166

167+
os << "#ifdef GEN_PASS_REGISTRATION\n";
158168
os << formatv(passGroupRegistrationCode, groupName);
159169

160170
for (const Pass &pass : passes)

0 commit comments

Comments
 (0)