|
26 | 26 | XOR, |
27 | 27 | Add, |
28 | 28 | IntDiv, |
| 29 | + Maximum, |
| 30 | + Minimum, |
29 | 31 | Mul, |
30 | | - ScalarMaximum, |
31 | | - ScalarMinimum, |
32 | 32 | Sub, |
33 | 33 | TrueDiv, |
34 | 34 | get_scalar_type, |
35 | | - scalar_maximum, |
| 35 | + maximum, |
36 | 36 | ) |
37 | 37 | from pytensor.scalar.basic import add as add_as |
38 | 38 | from pytensor.tensor.blas import BatchedDot |
@@ -104,16 +104,16 @@ def scalar_in_place_fn_IntDiv(op, idx, res, arr): |
104 | 104 | return f"{res}[{idx}] //= {arr}" |
105 | 105 |
|
106 | 106 |
|
107 | | -@scalar_in_place_fn.register(ScalarMaximum) |
108 | | -def scalar_in_place_fn_ScalarMaximum(op, idx, res, arr): |
| 107 | +@scalar_in_place_fn.register(Maximum) |
| 108 | +def scalar_in_place_fn_Maximum(op, idx, res, arr): |
109 | 109 | return f""" |
110 | 110 | if {res}[{idx}] < {arr}: |
111 | 111 | {res}[{idx}] = {arr} |
112 | 112 | """ |
113 | 113 |
|
114 | 114 |
|
115 | | -@scalar_in_place_fn.register(ScalarMinimum) |
116 | | -def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr): |
| 115 | +@scalar_in_place_fn.register(Minimum) |
| 116 | +def scalar_in_place_fn_Minimum(op, idx, res, arr): |
117 | 117 | return f""" |
118 | 118 | if {res}[{idx}] > {arr}: |
119 | 119 | {res}[{idx}] = {arr} |
@@ -459,7 +459,7 @@ def numba_funcify_Softmax(op, node, **kwargs): |
459 | 459 | if axis is not None: |
460 | 460 | axis = normalize_axis_index(axis, x_at.ndim) |
461 | 461 | reduce_max_py = create_multiaxis_reducer( |
462 | | - scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True |
| 462 | + maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True |
463 | 463 | ) |
464 | 464 | reduce_sum_py = create_multiaxis_reducer( |
465 | 465 | add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True |
@@ -523,7 +523,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): |
523 | 523 | if axis is not None: |
524 | 524 | axis = normalize_axis_index(axis, x_at.ndim) |
525 | 525 | reduce_max_py = create_multiaxis_reducer( |
526 | | - scalar_maximum, |
| 526 | + maximum, |
527 | 527 | -np.inf, |
528 | 528 | (axis,), |
529 | 529 | x_at.ndim, |
|
0 commit comments