diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index ce14d08246..5c909eb225 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -1,5 +1,8 @@ from copy import deepcopy +import mlx.core as mx +import numpy as np + from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -13,12 +16,52 @@ from pytensor.tensor.type_other import MakeSlice +def normalize_indices_for_mlx(indices): + """Convert numpy integers to Python integers for MLX indexing. + + MLX requires index values to be Python int, not np.int64 or other NumPy types. + """ + + def to_int(value, element): + """Convert value to Python int with helpful error message.""" + try: + return int(value) + except (TypeError, ValueError) as e: + raise TypeError( + "MLX backend does not support symbolic indices. " + "Index values must be concrete (constant) integers, not symbolic variables. " + f"Got: {element}" + ) from e + + def normalize_element(element): + if element is None: + return None + elif isinstance(element, slice): + return slice( + normalize_element(element.start), + normalize_element(element.stop), + normalize_element(element.step), + ) + elif isinstance(element, mx.array) and element.ndim == 0: + return to_int(element.item(), element) + elif isinstance(element, np.integer): + return to_int(element, element) + else: + return element + + return tuple(normalize_element(idx) for idx in indices) + + @mlx_funcify.register(Subtensor) def mlx_funcify_Subtensor(op, node, **kwargs): - idx_list = getattr(op, "idx_list", None) + """MLX implementation of Subtensor.""" + idx_list = op.idx_list def subtensor(x, *ilists): - indices = indices_from_subtensor([int(element) for element in ilists], idx_list) + # Convert ilist to indices using idx_list (basic subtensor) + indices = indices_from_subtensor(ilists, idx_list) + # Normalize indices to handle np.int64 and other NumPy types + indices = normalize_indices_for_mlx(indices) if len(indices) == 1: indices = indices[0] @@ -30,10 +73,12 @@ def subtensor(x, *ilists): @mlx_funcify.register(AdvancedSubtensor) @mlx_funcify.register(AdvancedSubtensor1) def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): - idx_list = getattr(op, "idx_list", None) + """MLX implementation of AdvancedSubtensor.""" def advanced_subtensor(x, *ilists): - indices = indices_from_subtensor(ilists, idx_list) + # Normalize indices to handle np.int64 and other NumPy types + # Advanced indexing doesn't use idx_list or indices_from_subtensor + indices = normalize_indices_for_mlx(ilists) if len(indices) == 1: indices = indices[0] @@ -43,11 +88,11 @@ def advanced_subtensor(x, *ilists): @mlx_funcify.register(IncSubtensor) -@mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_IncSubtensor(op, node, **kwargs): - idx_list = getattr(op, "idx_list", None) + """MLX implementation of IncSubtensor.""" + idx_list = op.idx_list - if getattr(op, "set_instead_of_inc", False): + if op.set_instead_of_inc: def mlx_fn(x, indices, y): if not op.inplace: @@ -64,7 +109,11 @@ def mlx_fn(x, indices, y): return x def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): + # Convert ilist to indices using idx_list (basic inc_subtensor) indices = indices_from_subtensor(ilist, idx_list) + # Normalize indices to handle np.int64 and other NumPy types + indices = normalize_indices_for_mlx(indices) + if len(indices) == 1: indices = indices[0] @@ -74,8 +123,11 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): @mlx_funcify.register(AdvancedIncSubtensor) +@mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): - if getattr(op, "set_instead_of_inc", False): + """MLX implementation of AdvancedIncSubtensor.""" + + if op.set_instead_of_inc: def mlx_fn(x, indices, y): if not op.inplace: @@ -92,7 +144,15 @@ def mlx_fn(x, indices, y): return x def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn): - return mlx_fn(x, ilist, y) + # Normalize indices to handle np.int64 and other NumPy types + # Advanced indexing doesn't use idx_list or indices_from_subtensor + indices = normalize_indices_for_mlx(ilist) + + # For advanced indexing, if we have a single tuple of indices, unwrap it + if len(indices) == 1: + indices = indices[0] + + return mlx_fn(x, indices, y) return advancedincsubtensor diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index 2923411799..cc4c108956 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -119,6 +119,19 @@ def test_mlx_IncSubtensor_increment(): assert not out_pt.owner.op.set_instead_of_inc compare_mlx_and_py([], [out_pt], []) + # Increment slice + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, 2:], st_pt) + compare_mlx_and_py([], [out_pt], []) + + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, -3:], st_pt) + compare_mlx_and_py([], [out_pt], []) + + out_pt = pt_subtensor.inc_subtensor(x_pt[::2, ::2, ::2], st_pt) + compare_mlx_and_py([], [out_pt], []) + + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, :], st_pt) + compare_mlx_and_py([], [out_pt], []) + def test_mlx_AdvancedIncSubtensor_set(): """Test advanced set operations using AdvancedIncSubtensor.""" @@ -232,9 +245,12 @@ def test_mlx_subtensor_edge_cases(): compare_mlx_and_py([], [out_pt], []) -@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") def test_mlx_subtensor_with_variables(): - """Test subtensor operations with PyTensor variables as inputs.""" + """Test subtensor operations with PyTensor variables as inputs. + + This test now works thanks to the fix for np.int64 indexing, which also + handles the conversion of MLX scalar arrays in slice components. + """ # Test with variable arrays (not constants) x_pt = pt.matrix("x", dtype="float32") y_pt = pt.vector("y", dtype="float32") @@ -245,3 +261,153 @@ def test_mlx_subtensor_with_variables(): # Set operation with variables out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt) compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np]) + + +def test_mlx_subtensor_with_numpy_int64(): + """Test Subtensor operations with np.int64 indices. + + This tests the fix for MLX's strict requirement that indices must be + Python int, not np.int64 or other NumPy integer types. + """ + # Test data + x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) + x_pt = pt.constant(x_np) + + # Single np.int64 index - this was failing before the fix + idx = np.int64(1) + out_pt = x_pt[idx] + compare_mlx_and_py([], [out_pt], []) + + # Multiple np.int64 indices + out_pt = x_pt[np.int64(1), np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + # Negative np.int64 index + out_pt = x_pt[np.int64(-1)] + compare_mlx_and_py([], [out_pt], []) + + # Mixed Python int and np.int64 + out_pt = x_pt[1, np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_subtensor_slices_with_numpy_int64(): + """Test Subtensor with slices containing np.int64 components. + + This tests that slice start/stop/step values can be np.int64. + """ + x_np = np.arange(20, dtype=np.float32) + x_pt = pt.constant(x_np) + + # Slice with np.int64 start + out_pt = x_pt[np.int64(2) :] + compare_mlx_and_py([], [out_pt], []) + + # Slice with np.int64 stop + out_pt = x_pt[: np.int64(5)] + compare_mlx_and_py([], [out_pt], []) + + # Slice with np.int64 start and stop + out_pt = x_pt[np.int64(2) : np.int64(8)] + compare_mlx_and_py([], [out_pt], []) + + # Slice with np.int64 step + out_pt = x_pt[:: np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + # Slice with all np.int64 components + out_pt = x_pt[np.int64(1) : np.int64(10) : np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + # Negative np.int64 in slice + out_pt = x_pt[np.int64(-5) :] + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_incsubtensor_with_numpy_int64(): + """Test IncSubtensor (set/inc) with np.int64 indices and slices. + + This is the main test for the reported issue with inc_subtensor. + """ + # Test data + x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) + x_pt = pt.constant(x_np) + y_pt = pt.as_tensor_variable(np.array(10.0, dtype=np.float32)) + + # Set with np.int64 index + out_pt = pt_subtensor.set_subtensor(x_pt[np.int64(1), np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Increment with np.int64 index + out_pt = pt_subtensor.inc_subtensor(x_pt[np.int64(1), np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Set with slice containing np.int64 - THE ORIGINAL FAILING CASE + out_pt = pt_subtensor.set_subtensor(x_pt[:, : np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Increment with slice containing np.int64 - THE ORIGINAL FAILING CASE + out_pt = pt_subtensor.inc_subtensor(x_pt[:, : np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Complex slice with np.int64 + y2_pt = pt.as_tensor_variable(np.ones((2, 2), dtype=np.float32)) + out_pt = pt_subtensor.inc_subtensor( + x_pt[np.int64(0) : np.int64(2), np.int64(1) : np.int64(3)], y2_pt + ) + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_incsubtensor_original_issue(): + """Test the exact example from the issue report. + + This was failing with: ValueError: Slice indices must be integers or None. + """ + x_np = np.arange(9, dtype=np.float64).reshape((3, 3)) + x_pt = pt.constant(x_np, dtype="float64") + + # The exact failing case from the issue + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :2], 10) + compare_mlx_and_py([], [out_pt], []) + + # Verify it also works with set_subtensor + out_pt = pt_subtensor.set_subtensor(x_pt[:, :2], 10) + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_advanced_subtensor_with_numpy_int64(): + """Test AdvancedSubtensor with np.int64 in mixed indexing.""" + x_np = np.arange(24, dtype=np.float32).reshape((3, 4, 2)) + x_pt = pt.constant(x_np) + + # Advanced indexing with list, but other dimensions use np.int64 + # Note: This creates AdvancedSubtensor, not basic Subtensor + out_pt = x_pt[[0, 2], np.int64(1)] + compare_mlx_and_py([], [out_pt], []) + + # Mixed advanced and basic indexing with np.int64 in slice + out_pt = x_pt[[0, 2], np.int64(1) : np.int64(3)] + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_advanced_incsubtensor_with_numpy_int64(): + """Test AdvancedIncSubtensor with np.int64.""" + x_np = np.arange(15, dtype=np.float32).reshape((5, 3)) + x_pt = pt.constant(x_np) + + # Value to set/increment - using 4 rows now for vector indexing + y_pt = pt.as_tensor_variable( + np.array( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], + dtype=np.float32, + ) + ) + + # Advanced indexing set with vector array indices + indices = np.array([0, 1, 2, 3], dtype=np.int64) + out_pt = pt_subtensor.set_subtensor(x_pt[indices], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Advanced indexing increment + out_pt = pt_subtensor.inc_subtensor(x_pt[indices], y_pt) + compare_mlx_and_py([], [out_pt], [])