@@ -2163,7 +2163,7 @@ def test_local_mul_div_switch_sink_cast(self, op, rewrite):
21632163 # The zero branch upcasts the output, so we can't ignore its dtype
21642164 zero_branch = constant (np .array (0 , dtype = "float64" ), name = "zero_branch" )
21652165 other_branch = scalar ("other_branch" , dtype = "float32" )
2166- outer_var = scalar ("mul_var " , dtype = "bool" )
2166+ outer_var = scalar ("outer_var " , dtype = "bool" )
21672167
21682168 out = op (switch (cond , zero_branch , other_branch ), outer_var )
21692169 fgraph = FunctionGraph (outputs = [out ], clone = False )
@@ -2173,6 +2173,27 @@ def test_local_mul_div_switch_sink_cast(self, op, rewrite):
21732173 expected_out = switch (cond , zero_branch , op (other_branch , outer_var ))
21742174 assert equal_computations ([new_out ], [expected_out ])
21752175
2176+ @pytest .mark .parametrize (
2177+ "op, rewrite" , [(mul , local_mul_switch_sink ), (true_div , local_div_switch_sink )]
2178+ )
2179+ def test_local_mul_div_switch_sink_branch_order (self , op , rewrite ):
2180+ cond = scalar ("cond" , dtype = "bool" )
2181+ zero_branch = constant (np .array (0.0 , dtype = "float64" ), "zero_branch" )
2182+ other_branch = scalar ("other_branch" , dtype = "float64" )
2183+ outer_var = scalar ("outer_var" , dtype = "float64" )
2184+
2185+ left = op (switch (cond , zero_branch , other_branch ), outer_var )
2186+ right = op (switch (cond , other_branch , zero_branch ), outer_var )
2187+ fgraph = FunctionGraph (outputs = [left , right ], clone = False )
2188+ [new_left ] = rewrite .transform (fgraph , left .owner )
2189+ [new_right ] = rewrite .transform (fgraph , right .owner )
2190+
2191+ expected_left = switch (cond , zero_branch , op (other_branch , outer_var ))
2192+ expected_right = switch (cond , op (other_branch , outer_var ), zero_branch )
2193+ assert equal_computations (
2194+ [new_left , new_right ], [expected_left , expected_right ]
2195+ )
2196+
21762197
21772198@pytest .mark .skipif (
21782199 config .cxx == "" ,
0 commit comments