
Commit 17c675a

Remove scalar_ prefix from several Ops (#1683)

* Initial plan
* Rename ScalarMaximum/ScalarMinimum to Maximum/Minimum
* Apply ruff formatting fixes
* Remove custom names when class name matches desired name
* Remove scalar_ prefix from log1mexp, xlogx, xlogy0 and fix numba imports
* Fix xtensor test to use backward compat aliases in skip list

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>

1 parent 3082ed5

File tree

14 files changed: +63 -68 lines changed


pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 8 additions & 8 deletions

@@ -25,11 +25,11 @@
     IsNan,
     Log,
     Log1p,
+    Maximum,
+    Minimum,
     Mul,
     Neg,
     Pow,
-    ScalarMaximum,
-    ScalarMinimum,
     Sign,
     Sin,
     Sqr,
@@ -105,15 +105,15 @@ def any_reduce(x):
     return any_reduce


-@mlx_funcify_CAReduce_scalar_op.register(ScalarMaximum)
+@mlx_funcify_CAReduce_scalar_op.register(Maximum)
 def mlx_funcify_CARreduce_Maximum(scalar_op, axis):
     def max_reduce(x):
         return mx.max(x, axis=axis)

     return max_reduce


-@mlx_funcify_CAReduce_scalar_op.register(ScalarMinimum)
+@mlx_funcify_CAReduce_scalar_op.register(Minimum)
 def mlx_funcify_CARreduce_Minimum(scalar_op, axis):
     def min_reduce(x):
         return mx.min(x, axis=axis)
@@ -354,13 +354,13 @@ def mlx_funcify_Elemwise_scalar_OR(scalar_op):
     return mx.bitwise_or


-@mlx_funcify_Elemwise_scalar_op.register(ScalarMaximum)
-def mlx_funcify_Elemwise_scalar_ScalarMaximum(scalar_op):
+@mlx_funcify_Elemwise_scalar_op.register(Maximum)
+def mlx_funcify_Elemwise_scalar_Maximum(scalar_op):
    return mx.maximum


-@mlx_funcify_Elemwise_scalar_op.register(ScalarMinimum)
-def mlx_funcify_Elemwise_scalar_ScalarMinimum(scalar_op):
+@mlx_funcify_Elemwise_scalar_op.register(Minimum)
+def mlx_funcify_Elemwise_scalar_Minimum(scalar_op):
    return mx.minimum
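
The registrations above follow Python's functools.singledispatch pattern: renaming the scalar Op class only changes the key under which the translation is registered, so dispatch by type keeps working unchanged. A minimal standalone sketch of the pattern (hypothetical stand-in class, with NumPy substituting for mlx.core so it runs anywhere):

from functools import singledispatch

import numpy as np


class Maximum: ...  # stand-in for pytensor.scalar.basic.Maximum


@singledispatch
def funcify_scalar_op(scalar_op):
    # Fallback for scalar Ops without a registered translation
    raise NotImplementedError(f"No translation for {scalar_op}")


@funcify_scalar_op.register(Maximum)
def _(scalar_op):
    # np.maximum stands in for mx.maximum in the real dispatcher
    return np.maximum


print(funcify_scalar_op(Maximum())(np.array([1, 5]), np.array([3, 2])))  # [3 5]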

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 9 additions & 9 deletions

@@ -26,13 +26,13 @@
     XOR,
     Add,
     IntDiv,
+    Maximum,
+    Minimum,
     Mul,
-    ScalarMaximum,
-    ScalarMinimum,
     Sub,
     TrueDiv,
     get_scalar_type,
-    scalar_maximum,
+    maximum,
 )
 from pytensor.scalar.basic import add as add_as
 from pytensor.tensor.blas import BatchedDot
@@ -104,16 +104,16 @@ def scalar_in_place_fn_IntDiv(op, idx, res, arr):
     return f"{res}[{idx}] //= {arr}"


-@scalar_in_place_fn.register(ScalarMaximum)
-def scalar_in_place_fn_ScalarMaximum(op, idx, res, arr):
+@scalar_in_place_fn.register(Maximum)
+def scalar_in_place_fn_Maximum(op, idx, res, arr):
     return f"""
 if {res}[{idx}] < {arr}:
     {res}[{idx}] = {arr}
 """


-@scalar_in_place_fn.register(ScalarMinimum)
-def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr):
+@scalar_in_place_fn.register(Minimum)
+def scalar_in_place_fn_Minimum(op, idx, res, arr):
     return f"""
 if {res}[{idx}] > {arr}:
     {res}[{idx}] = {arr}
 """
@@ -459,7 +459,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
     if axis is not None:
         axis = normalize_axis_index(axis, x_at.ndim)
         reduce_max_py = create_multiaxis_reducer(
-            scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
+            maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
         )
         reduce_sum_py = create_multiaxis_reducer(
             add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
@@ -523,7 +523,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
     if axis is not None:
         axis = normalize_axis_index(axis, x_at.ndim)
         reduce_max_py = create_multiaxis_reducer(
-            scalar_maximum,
+            maximum,
             -np.inf,
             (axis,),
             x_at.ndim,
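
The renamed Maximum template above is a plain source-string generator; copied out of the diff and called with placeholder arguments (illustrative only), it produces the in-place update code that the Numba backend later compiles:

def scalar_in_place_fn_Maximum(op, idx, res, arr):
    # Verbatim from the diff above; returns Python source, not a value.
    return f"""
if {res}[{idx}] < {arr}:
    {res}[{idx}] = {arr}
"""

print(scalar_in_place_fn_Maximum(None, "i", "res", "val"))
# if res[i] < val:
#     res[i] = val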

pytensor/scalar/basic.py

Lines changed: 12 additions & 4 deletions

@@ -1855,7 +1855,7 @@ def c_code(self, node, name, inputs, outputs, sub):
 ##############
 # Arithmetic
 ##############
-class ScalarMaximum(BinaryScalarOp):
+class Maximum(BinaryScalarOp):
     commutative = True
     associative = True
     nfunc_spec = ("maximum", 2, 1)
@@ -1895,10 +1895,14 @@ def L_op(self, inputs, outputs, gout):
         return (gx, gy)


-scalar_maximum = ScalarMaximum(upcast_out, name="maximum")
+maximum = Maximum(upcast_out)

+# Backward compatibility
+ScalarMaximum = Maximum
+scalar_maximum = maximum

-class ScalarMinimum(BinaryScalarOp):
+
+class Minimum(BinaryScalarOp):
     commutative = True
     associative = True
     nfunc_spec = ("minimum", 2, 1)
@@ -1937,7 +1941,11 @@ def L_op(self, inputs, outputs, gout):
         return (gx, gy)


-scalar_minimum = ScalarMinimum(upcast_out, name="minimum")
+minimum = Minimum(upcast_out)
+
+# Backward compatibility
+ScalarMinimum = Minimum
+scalar_minimum = minimum


 class Add(ScalarOp):
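
Because the old identifiers survive as straight aliases (taken verbatim from the diff above), code written against the previous names keeps working in a build that includes this commit:

from pytensor.scalar.basic import Maximum, ScalarMaximum, maximum, scalar_maximum

assert ScalarMaximum is Maximum   # class alias
assert scalar_maximum is maximum  # instance alias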

pytensor/scalar/math.py

Lines changed: 5 additions & 7 deletions

@@ -32,8 +32,8 @@
     isinf,
     log,
     log1p,
+    maximum,
     reciprocal,
-    scalar_maximum,
     sqrt,
     switch,
     true_div,
@@ -1315,7 +1315,7 @@ def c_code_cache_version(self):
         return v


-softplus = Softplus(upgrade_to_float, name="scalar_softplus")
+softplus = Softplus(upgrade_to_float)


 class Log1mexp(UnaryScalarOp):
@@ -1360,7 +1360,7 @@ def c_code(self, node, name, inp, out, sub):
         raise NotImplementedError("only floating point is implemented")


-log1mexp = Log1mexp(upgrade_to_float, name="scalar_log1mexp")
+log1mexp = Log1mexp(upgrade_to_float)


 class BetaInc(ScalarOp):
@@ -1585,9 +1585,7 @@ def inner_loop(
         derivative_new = K * (F1 * dK + F2)

         errapx = scalar_abs(derivative - derivative_new)
-        d_errapx = errapx / scalar_maximum(
-            err_threshold, scalar_abs(derivative_new)
-        )
+        d_errapx = errapx / maximum(err_threshold, scalar_abs(derivative_new))

         min_iters_cond = n > (min_iters - 1)
         derivative = switch(
@@ -1833,7 +1831,7 @@ def inner_loop(*args):
         if len(grad_incs) == 1:
             [max_abs_grad_inc] = grad_incs
         else:
-            max_abs_grad_inc = reduce(scalar_maximum, abs_grad_incs)
+            max_abs_grad_inc = reduce(maximum, abs_grad_incs)

         return (
             (*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k),
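
The maximum(err_threshold, ...) guard in the inner loop above is a common way to turn an absolute error into a relative one without dividing by a value near zero; a plain-Python sketch of the same idea (variable names borrowed from the diff, values illustrative):

err_threshold = 1e-12
derivative, derivative_new = 0.5, 0.5000001

errapx = abs(derivative - derivative_new)
# max(threshold, |new|) keeps the ratio finite even when
# derivative_new is at or near zero.
d_errapx = errapx / max(err_threshold, abs(derivative_new))
print(d_errapx)  # ~2e-07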

pytensor/tensor/blas.py

Lines changed: 2 additions & 2 deletions

@@ -948,8 +948,8 @@ def infer_shape(self, fgraph, node, input_shapes):
         z_shape, _, x_shape, y_shape, _ = input_shapes
         return [
             (
-                pytensor.scalar.scalar_maximum(z_shape[0], x_shape[0]),
-                pytensor.scalar.scalar_maximum(z_shape[1], y_shape[1]),
+                pytensor.scalar.maximum(z_shape[0], x_shape[0]),
+                pytensor.scalar.maximum(z_shape[1], y_shape[1]),
             )
         ]

pytensor/tensor/inplace.py

Lines changed: 2 additions & 2 deletions

@@ -357,12 +357,12 @@ def second_inplace(a):
 pprint.assign(fill_inplace, printing.FunctionPrinter(["fill="]))


-@scalar_elemwise(symbolname="scalar_maximum_inplace")
+@scalar_elemwise
 def maximum_inplace(a, b):
     """elementwise addition (inplace on `a`)"""


-@scalar_elemwise(symbolname="scalar_minimum_inplace")
+@scalar_elemwise
 def minimum_inplace(a, b):
     """elementwise addition (inplace on `a`)"""
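
Dropping the symbolname argument works because a decorator of this shape can fall back to the wrapped function's __name__ when called bare; a hypothetical standalone sketch of that calling convention (not PyTensor's actual scalar_elemwise implementation):

def scalar_elemwise(*args, symbolname=None):
    def build(fn, name):
        fn.symbolname = name  # stand-in for wiring up the real Elemwise
        return fn

    if args and callable(args[0]):
        # Used bare: @scalar_elemwise
        return build(args[0], args[0].__name__)
    # Used with arguments: @scalar_elemwise(symbolname="...")
    return lambda fn: build(fn, symbolname)


@scalar_elemwise
def maximum_inplace(a, b): ...


assert maximum_inplace.symbolname == "maximum_inplace"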

pytensor/tensor/math.py

Lines changed: 4 additions & 4 deletions

@@ -399,7 +399,7 @@ class Max(NonZeroDimsCAReduce):
     nfunc_spec = ("max", 1, 1)

     def __init__(self, axis):
-        super().__init__(ps.scalar_maximum, axis)
+        super().__init__(ps.maximum, axis)

     def clone(self, **kwargs):
         axis = kwargs.get("axis", self.axis)
@@ -457,7 +457,7 @@ class Min(NonZeroDimsCAReduce):
     nfunc_spec = ("min", 1, 1)

     def __init__(self, axis):
-        super().__init__(ps.scalar_minimum, axis)
+        super().__init__(ps.minimum, axis)

     def clone(self, **kwargs):
         axis = kwargs.get("axis", self.axis)
@@ -2755,7 +2755,7 @@ def median(x: TensorLike, axis=None) -> TensorVariable:
     return ifelse(even_k, even_median, odd_median, name="median")


-@scalar_elemwise(symbolname="scalar_maximum")
+@scalar_elemwise
 def maximum(x, y):
     """elemwise maximum. See max for the maximum in one tensor
@@ -2791,7 +2791,7 @@ def maximum(x, y):
     # see decorator for function body


-@scalar_elemwise(symbolname="scalar_minimum")
+@scalar_elemwise
 def minimum(x, y):
     """elemwise minimum. See min for the minimum in one tensor
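
None of this changes the public tensor API: pt.maximum and pt.minimum behave exactly as before; only the scalar Op backing them lost its scalar_ prefix. A quick sanity check against a build containing this commit:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")
f = pytensor.function([x, y], pt.maximum(x, y))
print(f([1.0, 5.0], [3.0, 2.0]))  # [3. 5.]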

pytensor/tensor/rewriting/uncanonicalize.py

Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ def local_max_to_min(fgraph, node):
     if (
         max.owner
         and isinstance(max.owner.op, CAReduce)
-        and max.owner.op.scalar_op == ps.scalar_maximum
+        and max.owner.op.scalar_op == ps.maximum
     ):
         neg_node = max.owner.inputs[0]
         if neg_node.owner and neg_node.owner.op == neg:
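
For context, local_max_to_min matches a max reduction applied to a negated input, which can be rewritten via the identity max(-x) = -min(x); a NumPy check of that identity:

import numpy as np

x = np.array([1.0, -2.0, 3.0])
assert np.max(-x) == -np.min(x)  # both equal 2.0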

pytensor/tensor/xlogx.py

Lines changed: 2 additions & 2 deletions

@@ -31,7 +31,7 @@ def c_code(self, node, name, inputs, outputs, sub):
         raise NotImplementedError("only floatingpoint is implemented")


-scalar_xlogx = XlogX(ps.upgrade_to_float, name="scalar_xlogx")
+scalar_xlogx = XlogX(ps.upgrade_to_float)
 xlogx = Elemwise(scalar_xlogx, name="xlogx")


@@ -62,5 +62,5 @@ def c_code(self, node, name, inputs, outputs, sub):
         raise NotImplementedError("only floatingpoint is implemented")


-scalar_xlogy0 = XlogY0(ps.upgrade_to_float, name="scalar_xlogy0")
+scalar_xlogy0 = XlogY0(ps.upgrade_to_float)
 xlogy0 = Elemwise(scalar_xlogy0, name="xlogy0")
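
For reference, XlogX computes x * log(x) (and XlogY0 computes x * log(y)); the sketch below assumes the usual convention that the result is taken to be 0 when x is 0, and is an independent NumPy rendering rather than the Ops themselves:

import numpy as np

def xlogx(x):
    # x * log(x), treating 0 * log(0) as 0 (assumed convention);
    # the inner where avoids evaluating log(0).
    x = np.asarray(x, dtype=float)
    return np.where(x == 0.0, 0.0, x * np.log(np.where(x == 0.0, 1.0, x)))

print(xlogx(np.array([0.0, 1.0, np.e])))  # [0. 0. 2.71828183]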

pytensor/xtensor/math.py

Lines changed: 2 additions & 2 deletions

@@ -388,11 +388,11 @@ def reciprocal(): ...
 def round(): ...


-@_as_xelemwise(ps.scalar_maximum)
+@_as_xelemwise(ps.maximum)
 def maximum(): ...


-@_as_xelemwise(ps.scalar_minimum)
+@_as_xelemwise(ps.minimum)
 def minimum(): ...
