@@ -57,19 +57,23 @@ const char *const passRegistrationCode = R"(
5757//===----------------------------------------------------------------------===//
5858// {0} Registration
5959//===----------------------------------------------------------------------===//
60+ #ifdef {1}
6061
6162inline void register{0}() {{
6263 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
63- return {1 };
64+ return {2 };
6465 });
6566}
6667
6768// Old registration code, kept for temporary backwards compatibility.
6869inline void register{0}Pass() {{
6970 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
70- return {1 };
71+ return {2 };
7172 });
7273}
74+
75+ #undef {1}
76+ #endif // {1}
7377)" ;
7478
7579// / The code snippet used to generate a function to register all passes in a
@@ -116,6 +120,10 @@ static std::string getPassDeclVarName(const Pass &pass) {
116120 return " GEN_PASS_DECL_" + pass.getDef ()->getName ().upper ();
117121}
118122
123+ static std::string getPassRegistrationVarName (const Pass &pass) {
124+ return " GEN_PASS_REGISTRATION_" + pass.getDef ()->getName ().upper ();
125+ }
126+
119127// / Emit the code to be included in the public header of the pass.
120128static void emitPassDecls (const Pass &pass, raw_ostream &os) {
121129 StringRef passName = pass.getDef ()->getName ();
@@ -143,18 +151,25 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) {
143151// / PassRegistry.
144152static void emitRegistrations (llvm::ArrayRef<Pass> passes, raw_ostream &os) {
145153 os << " #ifdef GEN_PASS_REGISTRATION\n " ;
154+ os << " // Generate registrations for all passes.\n " ;
155+ for (const Pass &pass : passes)
156+ os << " #define " << getPassRegistrationVarName (pass) << " \n " ;
157+ os << " #endif // GEN_PASS_REGISTRATION\n " ;
146158
147159 for (const Pass &pass : passes) {
160+ std::string passName = pass.getDef ()->getName ().str ();
161+ std::string passEnableVarName = getPassRegistrationVarName (pass);
162+
148163 std::string constructorCall;
149164 if (StringRef constructor = pass.getConstructor (); !constructor.empty ())
150165 constructorCall = constructor.str ();
151166 else
152- constructorCall = formatv (" create{0}()" , pass.getDef ()->getName ()).str ();
153-
154- os << formatv (passRegistrationCode, pass.getDef ()->getName (),
167+ constructorCall = formatv (" create{0}()" , passName).str ();
168+ os << formatv (passRegistrationCode, passName, passEnableVarName,
155169 constructorCall);
156170 }
157171
172+ os << " #ifdef GEN_PASS_REGISTRATION\n " ;
158173 os << formatv (passGroupRegistrationCode, groupName);
159174
160175 for (const Pass &pass : passes)
0 commit comments