Skip to content

Commit 6d65a06

Browse files
committed
[mlir][ods] Enable granular pass registration.
Same as with pass def & decl. This doesn't change anything with registry and the big flag kept (e.g., GEN_PASS_REGISTRATION behaves like GEN_PASS_DECL and so too for sub ones).
1 parent 750a361 commit 6d65a06

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

mlir/docs/PassManagement.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,12 @@ each pass, the generator produces a `registerPassName` where
835835
generates a `registerGroupPasses`, where `Group` is the tag provided via the
836836
`-name` input parameter, that registers all of the passes present.
837837

838+
These declarations can be enabled for the whole group of passes by
839+
defining the `GEN_PASS_REGISTRATION` macro, or on a per-pass basis by
840+
defining `GEN_PASS_REGISTRATION_PASSNAME` where `PASSNAME` is the
841+
uppercase version of the name of the pass (similar to pass def and
842+
decls).
843+
838844
```c++
839845
// Tablegen options: -gen-pass-decls -name="Example"
840846

mlir/tools/mlir-tblgen/PassGen.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,23 @@ const char *const passRegistrationCode = R"(
5757
//===----------------------------------------------------------------------===//
5858
// {0} Registration
5959
//===----------------------------------------------------------------------===//
60+
#ifdef {1}
6061
6162
inline void register{0}() {{
6263
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
63-
return {1};
64+
return {2};
6465
});
6566
}
6667
6768
// Old registration code, kept for temporary backwards compatibility.
6869
inline void register{0}Pass() {{
6970
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
70-
return {1};
71+
return {2};
7172
});
7273
}
74+
75+
#undef {1}
76+
#endif // {1}
7377
)";
7478

7579
/// The code snippet used to generate a function to register all passes in a
@@ -116,6 +120,10 @@ static std::string getPassDeclVarName(const Pass &pass) {
116120
return "GEN_PASS_DECL_" + pass.getDef()->getName().upper();
117121
}
118122

123+
static std::string getPassRegistrationVarName(const Pass &pass) {
124+
return "GEN_PASS_REGISTRATION_" + pass.getDef()->getName().upper();
125+
}
126+
119127
/// Emit the code to be included in the public header of the pass.
120128
static void emitPassDecls(const Pass &pass, raw_ostream &os) {
121129
StringRef passName = pass.getDef()->getName();
@@ -143,18 +151,25 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) {
143151
/// PassRegistry.
144152
static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
145153
os << "#ifdef GEN_PASS_REGISTRATION\n";
154+
os << "// Generate registrations for all passes.\n";
155+
for (const Pass &pass : passes)
156+
os << "#define " << getPassRegistrationVarName(pass) << "\n";
157+
os << "#endif // GEN_PASS_REGISTRATION\n";
146158

147159
for (const Pass &pass : passes) {
160+
std::string passName = pass.getDef()->getName().str();
161+
std::string passEnableVarName = getPassRegistrationVarName(pass);
162+
148163
std::string constructorCall;
149164
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
150165
constructorCall = constructor.str();
151166
else
152-
constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
153-
154-
os << formatv(passRegistrationCode, pass.getDef()->getName(),
167+
constructorCall = formatv("create{0}()", passName).str();
168+
os << formatv(passRegistrationCode, passName, passEnableVarName,
155169
constructorCall);
156170
}
157171

172+
os << "#ifdef GEN_PASS_REGISTRATION\n";
158173
os << formatv(passGroupRegistrationCode, groupName);
159174

160175
for (const Pass &pass : passes)

0 commit comments

Comments
 (0)