diff --git a/pytensor/gradient.py b/pytensor/gradient.py index ecdf4fbd4c..81fb4aeff5 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -494,22 +494,25 @@ def Lop( coordinates of the tensor elements. If `f` is a list/tuple, then return a list/tuple with the results. """ + from pytensor.tensor import as_tensor_variable + if not isinstance(eval_points, list | tuple): - _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)] - else: - _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points] + eval_points = [eval_points] + _eval_points = [ + x if isinstance(x, Variable) else as_tensor_variable(x) for x in eval_points + ] if not isinstance(f, list | tuple): - _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)] - else: - _f = [pytensor.tensor.as_tensor_variable(x) for x in f] + f = [f] + _f = [x if isinstance(x, Variable) else as_tensor_variable(x) for x in f] grads = list(_eval_points) + using_list = isinstance(wrt, list) + using_tuple = isinstance(wrt, tuple) if not isinstance(wrt, list | tuple): - _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)] - else: - _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt] + wrt = [wrt] + _wrt = [x if isinstance(x, Variable) else as_tensor_variable(x) for x in wrt] assert len(_f) == len(grads) known = dict(zip(_f, grads, strict=True)) @@ -523,8 +526,6 @@ def Lop( return_disconnected=return_disconnected, ) - using_list = isinstance(wrt, list) - using_tuple = isinstance(wrt, tuple) return as_list_or_tuple(using_list, using_tuple, ret) diff --git a/tests/test_gradient.py b/tests/test_gradient.py index a79746da6d..34e5d6b730 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -11,6 +11,7 @@ DisconnectedType, GradClip, GradScale, + Lop, NullTypeGradError, Rop, UndefinedGrad, @@ -32,6 +33,7 @@ from pytensor.graph.null_type import NullType from pytensor.graph.op import Op from pytensor.graph.traversal import graph_inputs +from pytensor.scalar import float64 from pytensor.scan.op import Scan from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, sqrt, tanh from pytensor.tensor.math import sum as pt_sum @@ -1207,3 +1209,13 @@ def test_multiple_wrt(self): hessp_x_eval, hessp_y_eval = hessp_fn(**test) np.testing.assert_allclose(hessp_x_eval, [2, 4, 6]) np.testing.assert_allclose(hessp_y_eval, [-6, -4, -2]) + + +def test_scalar_Lop(): + xtm1 = float64("xtm1") + xt = xtm1**2 + + dout_dxt = float64("dout_dxt") + dout_dxtm1 = Lop(xt, wrt=xtm1, eval_points=dout_dxt) + assert dout_dxtm1.type == dout_dxt.type + assert dout_dxtm1.eval({xtm1: 3.0, dout_dxt: 1.5}) == 2 * 3.0 * 1.5