1515from pytensor .tensor .elemwise import DimShuffle
1616from pytensor .tensor .rewriting .basic import register_specialize
1717from pytensor .tensor .rewriting .linalg import is_matrix_transpose
18- from pytensor .tensor .slinalg import Solve , lu_factor , lu_solve
18+ from pytensor .tensor .slinalg import Solve , cho_solve , cholesky , lu_factor , lu_solve
1919from pytensor .tensor .variable import TensorVariable
2020
2121
22- def decompose_A (A , assume_a , check_finite ):
22+ def decompose_A (A , assume_a , check_finite , lower ):
2323 if assume_a == "gen" :
2424 return lu_factor (A , check_finite = check_finite )
2525 elif assume_a == "tridiagonal" :
2626 # We didn't implement check_finite for tridiagonal LU factorization
2727 return tridiagonal_lu_factor (A )
28+ elif assume_a == "pos" :
29+ return cholesky (A , lower = lower , check_finite = check_finite )
2830 else :
2931 raise NotImplementedError
3032
3133
32- def solve_lu_decomposed_system (A_decomp , b , transposed = False , * , core_solve_op : Solve ):
34+ def solve_decomposed_system (
35+ A_decomp , b , transposed = False , lower = False , * , core_solve_op : Solve
36+ ):
3337 b_ndim = core_solve_op .b_ndim
3438 check_finite = core_solve_op .check_finite
3539 assume_a = core_solve_op .assume_a
40+
3641 if assume_a == "gen" :
3742 return lu_solve (
3843 A_decomp ,
@@ -49,11 +54,19 @@ def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op:
4954 b_ndim = b_ndim ,
5055 transposed = transposed ,
5156 )
57+ elif assume_a == "pos" :
58+ # We can ignore the transposed argument here because A is symmetric by assumption
59+ return cho_solve (
60+ (A_decomp , lower ),
61+ b ,
62+ b_ndim = b_ndim ,
63+ check_finite = check_finite ,
64+ )
5265 else :
5366 raise NotImplementedError
5467
5568
56- def _split_lu_solve_steps (
69+ def _split_decomp_and_solve_steps (
5770 fgraph , node , * , eager : bool , allowed_assume_a : Container [str ]
5871):
5972 if not isinstance (node .op .core_op , Solve ):
@@ -133,13 +146,21 @@ def find_solve_clients(var, assume_a):
133146 if client .op .core_op .check_finite :
134147 check_finite_decomp = True
135148 break
136- A_decomp = decompose_A (A , assume_a = assume_a , check_finite = check_finite_decomp )
149+
150+ lower = node .op .core_op .lower
151+ A_decomp = decompose_A (
152+ A , assume_a = assume_a , check_finite = check_finite_decomp , lower = lower
153+ )
137154
138155 replacements = {}
139156 for client , transposed in A_solve_clients_and_transpose :
140157 _ , b = client .inputs
141- new_x = solve_lu_decomposed_system (
142- A_decomp , b , transposed = transposed , core_solve_op = client .op .core_op
158+ new_x = solve_decomposed_system (
159+ A_decomp ,
160+ b ,
161+ transposed = transposed ,
162+ lower = lower ,
163+ core_solve_op = client .op .core_op ,
143164 )
144165 [old_x ] = client .outputs
145166 new_x = atleast_Nd (new_x , n = old_x .type .ndim ).astype (old_x .type .dtype )
@@ -149,7 +170,7 @@ def find_solve_clients(var, assume_a):
149170 return replacements
150171
151172
152- def _scan_split_non_sequence_lu_decomposition_solve (
173+ def _scan_split_non_sequence_decomposition_and_solve (
153174 fgraph , node , * , allowed_assume_a : Container [str ]
154175):
155176 """If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.
@@ -179,7 +200,7 @@ def _scan_split_non_sequence_lu_decomposition_solve(
179200 non_sequences = {equiv [non_seq ] for non_seq in non_sequences }
180201 inner_node = equiv [inner_node ] # type: ignore
181202
182- replace_dict = _split_lu_solve_steps (
203+ replace_dict = _split_decomp_and_solve_steps (
183204 new_scan_fgraph ,
184205 inner_node ,
185206 eager = True ,
@@ -207,22 +228,22 @@ def _scan_split_non_sequence_lu_decomposition_solve(
207228
208229@register_specialize
209230@node_rewriter ([Blockwise ])
210- def reuse_lu_decomposition_multiple_solves (fgraph , node ):
211- return _split_lu_solve_steps (
212- fgraph , node , eager = False , allowed_assume_a = {"gen" , "tridiagonal" }
231+ def reuse_decomposition_multiple_solves (fgraph , node ):
232+ return _split_decomp_and_solve_steps (
233+ fgraph , node , eager = False , allowed_assume_a = {"gen" , "tridiagonal" , "pos" }
213234 )
214235
215236
216237@node_rewriter ([Scan ])
217- def scan_split_non_sequence_lu_decomposition_solve (fgraph , node ):
218- return _scan_split_non_sequence_lu_decomposition_solve (
219- fgraph , node , allowed_assume_a = {"gen" , "tridiagonal" }
238+ def scan_split_non_sequence_decomposition_and_solve (fgraph , node ):
239+ return _scan_split_non_sequence_decomposition_and_solve (
240+ fgraph , node , allowed_assume_a = {"gen" , "tridiagonal" , "pos" }
220241 )
221242
222243
223244scan_seqopt1 .register (
224- "scan_split_non_sequence_lu_decomposition_solve" ,
225- in2out (scan_split_non_sequence_lu_decomposition_solve , ignore_newtrees = True ),
245+ scan_split_non_sequence_decomposition_and_solve . __name__ ,
246+ in2out (scan_split_non_sequence_decomposition_and_solve , ignore_newtrees = True ),
226247 "fast_run" ,
227248 "scan" ,
228249 "scan_pushout" ,
@@ -231,28 +252,30 @@ def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
231252
232253
233254@node_rewriter ([Blockwise ])
234- def reuse_lu_decomposition_multiple_solves_jax (fgraph , node ):
235- return _split_lu_solve_steps (fgraph , node , eager = False , allowed_assume_a = {"gen" })
255+ def reuse_decomposition_multiple_solves_jax (fgraph , node ):
256+ return _split_decomp_and_solve_steps (
257+ fgraph , node , eager = False , allowed_assume_a = {"gen" , "pos" }
258+ )
236259
237260
238261optdb ["specialize" ].register (
239- reuse_lu_decomposition_multiple_solves_jax .__name__ ,
240- in2out (reuse_lu_decomposition_multiple_solves_jax , ignore_newtrees = True ),
262+ reuse_decomposition_multiple_solves_jax .__name__ ,
263+ in2out (reuse_decomposition_multiple_solves_jax , ignore_newtrees = True ),
241264 "jax" ,
242265 use_db_name_as_tag = False ,
243266)
244267
245268
246269@node_rewriter ([Scan ])
247- def scan_split_non_sequence_lu_decomposition_solve_jax (fgraph , node ):
248- return _scan_split_non_sequence_lu_decomposition_solve (
249- fgraph , node , allowed_assume_a = {"gen" }
270+ def scan_split_non_sequence_decomposition_and_solve_jax (fgraph , node ):
271+ return _scan_split_non_sequence_decomposition_and_solve (
272+ fgraph , node , allowed_assume_a = {"gen" , "pos" }
250273 )
251274
252275
253276scan_seqopt1 .register (
254- scan_split_non_sequence_lu_decomposition_solve_jax .__name__ ,
255- in2out (scan_split_non_sequence_lu_decomposition_solve_jax , ignore_newtrees = True ),
277+ scan_split_non_sequence_decomposition_and_solve_jax .__name__ ,
278+ in2out (scan_split_non_sequence_decomposition_and_solve_jax , ignore_newtrees = True ),
256279 "jax" ,
257280 use_db_name_as_tag = False ,
258281 position = 2 ,
0 commit comments