@@ -58,18 +58,23 @@ const char *const passRegistrationCode = R"(
5858// {0} Registration
5959//===----------------------------------------------------------------------===//
6060
61+ #if defined(GEN_PASS_REGISTRATION) || defined({1})
62+
6163inline void register{0}() {{
6264 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
63- return {1 };
65+ return {2 };
6466 });
6567}
6668
6769// Old registration code, kept for temporary backwards compatibility.
6870inline void register{0}Pass() {{
6971 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
70- return {1 };
72+ return {2 };
7173 });
7274}
75+
76+ #undef {1}
77+ #endif // GEN_PASS_REGISTRATION || {1}
7378)" ;
7479
7580// / The code snippet used to generate a function to register all passes in a
@@ -116,6 +121,10 @@ static std::string getPassDeclVarName(const Pass &pass) {
116121 return " GEN_PASS_DECL_" + pass.getDef ()->getName ().upper ();
117122}
118123
124+ static std::string getPassRegistrationVarName (const Pass &pass) {
125+ return " GEN_PASS_REGISTRATION_" + pass.getDef ()->getName ().upper ();
126+ }
127+
119128// / Emit the code to be included in the public header of the pass.
120129static void emitPassDecls (const Pass &pass, raw_ostream &os) {
121130 StringRef passName = pass.getDef ()->getName ();
@@ -142,19 +151,20 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) {
142151// / Emit the code for registering each of the given passes with the global
143152// / PassRegistry.
144153static void emitRegistrations (llvm::ArrayRef<Pass> passes, raw_ostream &os) {
145- os << " #ifdef GEN_PASS_REGISTRATION\n " ;
146-
147154 for (const Pass &pass : passes) {
155+ std::string passName = pass.getDef ()->getName ().str ();
156+ std::string passEnableVarName = getPassRegistrationVarName (pass);
157+
148158 std::string constructorCall;
149159 if (StringRef constructor = pass.getConstructor (); !constructor.empty ())
150160 constructorCall = constructor.str ();
151161 else
152- constructorCall = formatv (" create{0}()" , pass.getDef ()->getName ()).str ();
153-
154- os << formatv (passRegistrationCode, pass.getDef ()->getName (),
162+ constructorCall = formatv (" create{0}()" , passName).str ();
163+ os << formatv (passRegistrationCode, passName, passEnableVarName,
155164 constructorCall);
156165 }
157166
167+ os << " #ifdef GEN_PASS_REGISTRATION\n " ;
158168 os << formatv (passGroupRegistrationCode, groupName);
159169
160170 for (const Pass &pass : passes)
0 commit comments