From bff36855ea769a1738d528b2c5f00c0814e7d070 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 23 Oct 2025 21:01:33 -0500 Subject: [PATCH 1/9] Handle slices in `mlx_funcify_IncSubtensor` --- pytensor/link/mlx/dispatch/subtensor.py | 20 +++++++++++++++++++- tests/link/mlx/test_subtensor.py | 13 +++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index ce14d08246..e37f05fbbd 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -64,7 +64,25 @@ def mlx_fn(x, indices, y): return x def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): - indices = indices_from_subtensor(ilist, idx_list) + def get_slice_int(element): + if element is None: + return None + try: + return int(element) + except Exception: + return element + + indices = tuple( + [ + slice( + get_slice_int(s.start), get_slice_int(s.stop), get_slice_int(s.step) + ) + if isinstance(s, slice) + else s + for s in indices_from_subtensor(ilist, idx_list) + ] + ) + if len(indices) == 1: indices = indices[0] diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index 2923411799..a13960807e 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -119,6 +119,19 @@ def test_mlx_IncSubtensor_increment(): assert not out_pt.owner.op.set_instead_of_inc compare_mlx_and_py([], [out_pt], []) + # Increment slice + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, 2:], st_pt) + compare_mlx_and_py([], [out_pt], []) + + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, -3:], st_pt) + compare_mlx_and_py([], [out_pt], []) + + out_pt = pt_subtensor.inc_subtensor(x_pt[::2, ::2, ::2], st_pt) + compare_mlx_and_py([], [out_pt], []) + + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, :], st_pt) + compare_mlx_and_py([], [out_pt], []) + def test_mlx_AdvancedIncSubtensor_set(): """Test advanced set operations using AdvancedIncSubtensor.""" From a63e759b5380b86ffbaa11bfeb3603bf1c34b806 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Tue, 28 Oct 2025 20:55:46 +0200 Subject: [PATCH 2/9] Create test_mlx_indexing_behavior.py --- tests/link/mlx/test_mlx_indexing_behavior.py | 212 +++++++++++++++++++ 1 file changed, 212 insertions(+) create mode 100644 tests/link/mlx/test_mlx_indexing_behavior.py diff --git a/tests/link/mlx/test_mlx_indexing_behavior.py b/tests/link/mlx/test_mlx_indexing_behavior.py new file mode 100644 index 0000000000..8021ae6126 --- /dev/null +++ b/tests/link/mlx/test_mlx_indexing_behavior.py @@ -0,0 +1,212 @@ +""" +Pure MLX indexing behavior tests. + +This module tests MLX's indexing capabilities with different index types +to understand what conversions are needed for PyTensor compatibility. +""" + +import numpy as np +import pytest + + +mx = pytest.importorskip("mlx.core") + + +def test_mlx_python_int_indexing(): + """Test that MLX accepts Python int for indexing.""" + x = mx.array([1, 2, 3, 4, 5]) + + # Single index + result = x[2] + assert result == 3 + + # Slice with Python int + result = x[1:4] + assert list(result) == [2, 3, 4] + + # Slice with step :) + result = x[0:5:2] + assert list(result) == [1, 3, 5] + + +def test_mlx_numpy_int64_single_index(): + """Test MLX behavior with np.int64 single index.""" + x = mx.array([1, 2, 3, 4, 5]) + + # This should fail with MLX + with pytest.raises(ValueError, match="Cannot index mlx array"): + _ = x[np.int64(2)] + + +def test_mlx_numpy_int64_in_slice(): + """Test MLX behavior with np.int64 in slice components.""" + x = mx.array([1, 2, 3, 4, 5]) + + # Slice with np.int64 start + with pytest.raises(ValueError, match="Slice indices must be integers or None"): + _ = x[np.int64(1) : 4] + + # Slice with np.int64 stop + with pytest.raises(ValueError, match="Slice indices must be integers or None"): + _ = x[1 : np.int64(4)] + + # Slice with np.int64 step + with pytest.raises(ValueError, match="Slice indices must be integers or None"): + _ = x[0 : 5 : np.int64(2)] + + +def test_mlx_conversion_int64_to_python_int(): + """Test that converting np.int64 to Python int works for MLX indexing.""" + x = mx.array([1, 2, 3, 4, 5]) + + # Convert np.int64 to Python int + idx = int(np.int64(2)) + result = x[idx] + assert result == 3 + + # Convert in slice + start = int(np.int64(1)) + stop = int(np.int64(4)) + step = int(np.int64(2)) + result = x[start:stop:step] + assert list(result) == [2, 4] + + +def test_mlx_slice_with_none(): + """Test that MLX accepts None in slice components.""" + x = mx.array([1, 2, 3, 4, 5]) + + # Slice with None + result = x[None:3] + assert list(result) == [1, 2, 3] + + result = x[2:None] + assert list(result) == [3, 4, 5] + + result = x[None:None:2] + assert list(result) == [1, 3, 5] + + +def test_mlx_multidimensional_indexing(): + """Test MLX indexing with multidimensional arrays.""" + x = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + # Python int indexing works + result = x[1, 2] + assert result == 6 + + # Mixed slice and int + result = x[1, :] + assert list(result) == [4, 5, 6] + + result = x[:, 1] + assert list(result) == [2, 5, 8] + + # Multiple slices + result = x[0:2, 1:3] + expected = [[2, 3], [5, 6]] + assert result.tolist() == expected + + +def test_mlx_negative_indices(): + """Test MLX with negative indices (both Python int and np.int64).""" + x = mx.array([1, 2, 3, 4, 5]) + + # Negative Python int works + result = x[-1] + assert result == 5 + + result = x[-3:-1] + assert list(result) == [3, 4] + + # Negative np.int64 should fail + with pytest.raises(ValueError, match="Cannot index mlx array"): + _ = x[np.int64(-1)] + + # But converting to Python int works + idx = int(np.int64(-1)) + result = x[idx] + assert result == 5 + + +def test_mlx_array_indexing(): + """Test MLX with array indices (advanced indexing).""" + x = mx.array([1, 2, 3, 4, 5]) + + # Array indexing with MLX array works + indices = mx.array([0, 2, 4]) + result = x[indices] + assert list(result) == [1, 3, 5] + + # Array indexing with NumPy array should fail + indices = np.array([0, 2, 4]) + with pytest.raises(ValueError, match="Cannot index mlx array"): + _ = x[indices] + + # But converting NumPy array to MLX array works + indices_mlx = mx.array(indices) + result = x[indices_mlx] + assert list(result) == [1, 3, 5] + + +def test_conversion_helper_behavior(): + """Test the behavior of our proposed int conversion helper.""" + + def get_slice_int(element): + """Helper to convert slice components to Python int.""" + if element is None: + return None + try: + return int(element) + except Exception: + return element + + # Test with None + assert get_slice_int(None) is None + + # Test with Python int + assert get_slice_int(5) == 5 + assert isinstance(get_slice_int(5), int) + + # Test with np.int64 + assert get_slice_int(np.int64(5)) == 5 + assert isinstance(get_slice_int(np.int64(5)), int) + + # Test with np.int32 + assert get_slice_int(np.int32(5)) == 5 + assert isinstance(get_slice_int(np.int32(5)), int) + + # Test with array (should pass through) + arr = np.array([1, 2, 3]) + result = get_slice_int(arr) + assert np.array_equal(result, arr) + + +def test_mlx_indexing_with_converted_slices(): + """Test that MLX indexing works after converting slice components.""" + x = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + def normalize_slice(s): + """Convert slice components to Python int.""" + if not isinstance(s, slice): + # For non-slice indices, try to convert to int + try: + return int(s) + except (TypeError, ValueError): + return s + + return slice( + int(s.start) if s.start is not None else None, + int(s.stop) if s.stop is not None else None, + int(s.step) if s.step is not None else None, + ) + + # Create slices with np.int64 + slice1 = slice(np.int64(0), np.int64(2), None) + slice2 = slice(None, np.int64(2), None) + + # Convert and use + normalized = (normalize_slice(slice1), normalize_slice(slice2)) + result = x[normalized] + expected = [[1, 2], [4, 5]] + assert result.tolist() == expected From 119e7e6d467159392a463832c276f2b2051dda13 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Tue, 28 Oct 2025 21:06:01 +0200 Subject: [PATCH 3/9] Fix MLX indexing to support np.int64 and NumPy scalars Introduces normalize_indices_for_mlx to convert NumPy integer and floating types, MLX scalar arrays, and slice components to Python int/float for MLX compatibility. Updates all MLX subtensor dispatch functions to use this normalization, resolving issues with MLX's strict indexing requirements. Adds comprehensive tests for np.int64 indices and slices in subtensor and inc_subtensor operations, including advanced indexing scenarios. --- pytensor/link/mlx/dispatch/subtensor.py | 164 ++++++++++++++++++++---- tests/link/mlx/test_subtensor.py | 154 +++++++++++++++++++++- 2 files changed, 293 insertions(+), 25 deletions(-) diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index e37f05fbbd..d695cfc9ce 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -1,5 +1,7 @@ from copy import deepcopy +import numpy as np + from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -13,12 +15,119 @@ from pytensor.tensor.type_other import MakeSlice +def normalize_indices_for_mlx(ilist, idx_list): + """Convert indices to MLX-compatible format. + + MLX has strict requirements for indexing: + - Integer indices must be Python int, not np.int64 or other NumPy integer types + - Slice components (start, stop, step) must be Python int or None, not np.int64 + - MLX arrays created from scalars need to be converted back to Python int + - Array indices for advanced indexing are handled separately + + This function converts all integer-like indices and slice components to Python int + while preserving None values and passing through array indices unchanged. + + Parameters + ---------- + ilist : tuple + Runtime index values to be passed to indices_from_subtensor + idx_list : tuple + Static index specification from the Op's idx_list attribute + + Returns + ------- + tuple + Normalized indices compatible with MLX array indexing + + Examples + -------- + >>> # Single np.int64 index converted to Python int + >>> normalize_indices_for_mlx((np.int64(1),), (True,)) + (1,) + + >>> # Slice with np.int64 components + >>> indices = indices_from_subtensor( + ... (np.int64(0), np.int64(2)), (slice(None, None),) + ... ) + >>> # After normalization, slice components are Python int + + Notes + ----- + This conversion is necessary because MLX's C++ indexing implementation + does not recognize NumPy scalar types, raising ValueError when encountered. + Additionally, mlx_typify converts NumPy scalars to MLX arrays, which also + need to be converted back to Python int for use in indexing operations. + Converting to Python int is zero-cost for Python int inputs and minimal + overhead for NumPy scalars and MLX scalar arrays. + """ + import mlx.core as mx + + def normalize_element(element): + """Convert a single index element to MLX-compatible format.""" + if element is None: + # None is valid in slices (e.g., x[None:5] or x[:None]) + return None + elif isinstance(element, slice): + # Recursively normalize slice components + return slice( + normalize_element(element.start), + normalize_element(element.stop), + normalize_element(element.step), + ) + elif isinstance(element, mx.array): + # MLX arrays from mlx_typify need special handling + # If it's a 0-d array (scalar), convert to Python int/float + if element.ndim == 0: + # Extract the scalar value + item = element.item() + # Convert to Python int if it's an integer type + if element.dtype in ( + mx.int8, + mx.int16, + mx.int32, + mx.int64, + mx.uint8, + mx.uint16, + mx.uint32, + mx.uint64, + ): + return int(item) + else: + return float(item) + else: + # Multi-dimensional array for advanced indexing - pass through + return element + elif isinstance(element, (np.integer, np.floating)): + # Convert NumPy scalar to Python int/float + # This handles np.int64, np.int32, np.float64, etc. + return int(element) if isinstance(element, np.integer) else float(element) + elif isinstance(element, (int, float)): + # Python int/float are already compatible + return element + else: + # Pass through other types (arrays for advanced indexing, etc.) + return element + + # Get indices from PyTensor's subtensor utility + raw_indices = indices_from_subtensor(ilist, idx_list) + + # Normalize each index element + normalized = tuple(normalize_element(idx) for idx in raw_indices) + + return normalized + + @mlx_funcify.register(Subtensor) def mlx_funcify_Subtensor(op, node, **kwargs): + """MLX implementation of Subtensor operation. + + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. + """ idx_list = getattr(op, "idx_list", None) def subtensor(x, *ilists): - indices = indices_from_subtensor([int(element) for element in ilists], idx_list) + # Normalize indices to handle np.int64 and other NumPy types + indices = normalize_indices_for_mlx(ilists, idx_list) if len(indices) == 1: indices = indices[0] @@ -30,10 +139,16 @@ def subtensor(x, *ilists): @mlx_funcify.register(AdvancedSubtensor) @mlx_funcify.register(AdvancedSubtensor1) def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): + """MLX implementation of AdvancedSubtensor operation. + + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX, + including handling np.int64 in mixed basic/advanced indexing scenarios. + """ idx_list = getattr(op, "idx_list", None) def advanced_subtensor(x, *ilists): - indices = indices_from_subtensor(ilists, idx_list) + # Normalize indices to handle np.int64 and other NumPy types + indices = normalize_indices_for_mlx(ilists, idx_list) if len(indices) == 1: indices = indices[0] @@ -45,6 +160,11 @@ def advanced_subtensor(x, *ilists): @mlx_funcify.register(IncSubtensor) @mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_IncSubtensor(op, node, **kwargs): + """MLX implementation of IncSubtensor operation. + + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. + Handles both set_instead_of_inc=True (assignment) and False (increment). + """ idx_list = getattr(op, "idx_list", None) if getattr(op, "set_instead_of_inc", False): @@ -64,24 +184,8 @@ def mlx_fn(x, indices, y): return x def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): - def get_slice_int(element): - if element is None: - return None - try: - return int(element) - except Exception: - return element - - indices = tuple( - [ - slice( - get_slice_int(s.start), get_slice_int(s.stop), get_slice_int(s.step) - ) - if isinstance(s, slice) - else s - for s in indices_from_subtensor(ilist, idx_list) - ] - ) + # Normalize indices to handle np.int64 and other NumPy types + indices = normalize_indices_for_mlx(ilist, idx_list) if len(indices) == 1: indices = indices[0] @@ -93,6 +197,13 @@ def get_slice_int(element): @mlx_funcify.register(AdvancedIncSubtensor) def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): + """MLX implementation of AdvancedIncSubtensor operation. + + Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. + Note: For advanced indexing, ilist contains the actual array indices. + """ + idx_list = getattr(op, "idx_list", None) + if getattr(op, "set_instead_of_inc", False): def mlx_fn(x, indices, y): @@ -109,8 +220,15 @@ def mlx_fn(x, indices, y): x[indices] += y return x - def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn): - return mlx_fn(x, ilist, y) + def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): + # Normalize indices to handle np.int64 and other NumPy types + indices = normalize_indices_for_mlx(ilist, idx_list) + + # For advanced indexing, if we have a single tuple of indices, unwrap it + if len(indices) == 1: + indices = indices[0] + + return mlx_fn(x, indices, y) return advancedincsubtensor @@ -120,4 +238,4 @@ def mlx_funcify_MakeSlice(op, **kwargs): def makeslice(*x): return slice(*x) - return makeslice + return makeslice \ No newline at end of file diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index a13960807e..d32954d46c 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -245,9 +245,12 @@ def test_mlx_subtensor_edge_cases(): compare_mlx_and_py([], [out_pt], []) -@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") def test_mlx_subtensor_with_variables(): - """Test subtensor operations with PyTensor variables as inputs.""" + """Test subtensor operations with PyTensor variables as inputs. + + This test now works thanks to the fix for np.int64 indexing, which also + handles the conversion of MLX scalar arrays in slice components. + """ # Test with variable arrays (not constants) x_pt = pt.matrix("x", dtype="float32") y_pt = pt.vector("y", dtype="float32") @@ -258,3 +261,150 @@ def test_mlx_subtensor_with_variables(): # Set operation with variables out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt) compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np]) + + +def test_mlx_subtensor_with_numpy_int64(): + """Test Subtensor operations with np.int64 indices. + + This tests the fix for MLX's strict requirement that indices must be + Python int, not np.int64 or other NumPy integer types. + """ + # Test data + x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) + x_pt = pt.constant(x_np) + + # Single np.int64 index - this was failing before the fix + idx = np.int64(1) + out_pt = x_pt[idx] + compare_mlx_and_py([], [out_pt], []) + + # Multiple np.int64 indices + out_pt = x_pt[np.int64(1), np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + # Negative np.int64 index + out_pt = x_pt[np.int64(-1)] + compare_mlx_and_py([], [out_pt], []) + + # Mixed Python int and np.int64 + out_pt = x_pt[1, np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_subtensor_slices_with_numpy_int64(): + """Test Subtensor with slices containing np.int64 components. + + This tests that slice start/stop/step values can be np.int64. + """ + x_np = np.arange(20, dtype=np.float32) + x_pt = pt.constant(x_np) + + # Slice with np.int64 start + out_pt = x_pt[np.int64(2) :] + compare_mlx_and_py([], [out_pt], []) + + # Slice with np.int64 stop + out_pt = x_pt[: np.int64(5)] + compare_mlx_and_py([], [out_pt], []) + + # Slice with np.int64 start and stop + out_pt = x_pt[np.int64(2) : np.int64(8)] + compare_mlx_and_py([], [out_pt], []) + + # Slice with np.int64 step + out_pt = x_pt[:: np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + # Slice with all np.int64 components + out_pt = x_pt[np.int64(1) : np.int64(10) : np.int64(2)] + compare_mlx_and_py([], [out_pt], []) + + # Negative np.int64 in slice + out_pt = x_pt[np.int64(-5) :] + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_incsubtensor_with_numpy_int64(): + """Test IncSubtensor (set/inc) with np.int64 indices and slices. + + This is the main test for the reported issue with inc_subtensor. + """ + # Test data + x_np = np.arange(12, dtype=np.float32).reshape((3, 4)) + x_pt = pt.constant(x_np) + y_pt = pt.as_tensor_variable(np.array(10.0, dtype=np.float32)) + + # Set with np.int64 index + out_pt = pt_subtensor.set_subtensor(x_pt[np.int64(1), np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Increment with np.int64 index + out_pt = pt_subtensor.inc_subtensor(x_pt[np.int64(1), np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Set with slice containing np.int64 - THE ORIGINAL FAILING CASE + out_pt = pt_subtensor.set_subtensor(x_pt[:, : np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Increment with slice containing np.int64 - THE ORIGINAL FAILING CASE + out_pt = pt_subtensor.inc_subtensor(x_pt[:, : np.int64(2)], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Complex slice with np.int64 + y2_pt = pt.as_tensor_variable(np.ones((2, 2), dtype=np.float32)) + out_pt = pt_subtensor.inc_subtensor( + x_pt[np.int64(0) : np.int64(2), np.int64(1) : np.int64(3)], y2_pt + ) + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_incsubtensor_original_issue(): + """Test the exact example from the issue report. + + This was failing with: ValueError: Slice indices must be integers or None. + """ + x_np = np.arange(9, dtype=np.float64).reshape((3, 3)) + x_pt = pt.constant(x_np, dtype="float64") + + # The exact failing case from the issue + out_pt = pt_subtensor.inc_subtensor(x_pt[:, :2], 10) + compare_mlx_and_py([], [out_pt], []) + + # Verify it also works with set_subtensor + out_pt = pt_subtensor.set_subtensor(x_pt[:, :2], 10) + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_advanced_subtensor_with_numpy_int64(): + """Test AdvancedSubtensor with np.int64 in mixed indexing.""" + x_np = np.arange(24, dtype=np.float32).reshape((3, 4, 2)) + x_pt = pt.constant(x_np) + + # Advanced indexing with list, but other dimensions use np.int64 + # Note: This creates AdvancedSubtensor, not basic Subtensor + out_pt = x_pt[[0, 2], np.int64(1)] + compare_mlx_and_py([], [out_pt], []) + + # Mixed advanced and basic indexing with np.int64 in slice + out_pt = x_pt[[0, 2], np.int64(1) : np.int64(3)] + compare_mlx_and_py([], [out_pt], []) + + +def test_mlx_advanced_incsubtensor_with_numpy_int64(): + """Test AdvancedIncSubtensor with np.int64.""" + x_np = np.arange(15, dtype=np.float32).reshape((5, 3)) + x_pt = pt.constant(x_np) + + # Value to set/increment + y_pt = pt.as_tensor_variable( + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) + ) + + # Advanced indexing set with array indices + indices = [np.int64(0), np.int64(2)] + out_pt = pt_subtensor.set_subtensor(x_pt[indices], y_pt) + compare_mlx_and_py([], [out_pt], []) + + # Advanced indexing increment + out_pt = pt_subtensor.inc_subtensor(x_pt[indices], y_pt) + compare_mlx_and_py([], [out_pt], []) \ No newline at end of file From 116a1bdf2ded1402db3b55ae100cc998ce8e9583 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Tue, 28 Oct 2025 21:08:08 +0200 Subject: [PATCH 4/9] Add missing newline at end of files Appended a newline to the end of subtensor.py and test_subtensor.py to conform with POSIX standards and improve code consistency. --- pytensor/link/mlx/dispatch/subtensor.py | 2 +- tests/link/mlx/test_subtensor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index d695cfc9ce..422d566f11 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -238,4 +238,4 @@ def mlx_funcify_MakeSlice(op, **kwargs): def makeslice(*x): return slice(*x) - return makeslice \ No newline at end of file + return makeslice diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index d32954d46c..3fa233fd57 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -407,4 +407,4 @@ def test_mlx_advanced_incsubtensor_with_numpy_int64(): # Advanced indexing increment out_pt = pt_subtensor.inc_subtensor(x_pt[indices], y_pt) - compare_mlx_and_py([], [out_pt], []) \ No newline at end of file + compare_mlx_and_py([], [out_pt], []) From 4f7ae9f31444d40912ea2c346fc5c16ec8a37499 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Tue, 28 Oct 2025 22:50:07 +0200 Subject: [PATCH 5/9] modify --- pytensor/link/mlx/dispatch/subtensor.py | 16 +- tests/link/mlx/test_mlx_indexing_behavior.py | 212 ------------------- 2 files changed, 7 insertions(+), 221 deletions(-) delete mode 100644 tests/link/mlx/test_mlx_indexing_behavior.py diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index 422d566f11..fca5bb8985 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -27,6 +27,13 @@ def normalize_indices_for_mlx(ilist, idx_list): This function converts all integer-like indices and slice components to Python int while preserving None values and passing through array indices unchanged. + This conversion is necessary because MLX's C++ indexing implementation + does not recognize NumPy scalar types, raising ValueError when encountered. + Additionally, mlx_typify converts NumPy scalars to MLX arrays, which also + need to be converted back to Python int for use in indexing operations. + Converting to Python int is zero-cost for Python int inputs and minimal + overhead for NumPy scalars and MLX scalar arrays. + Parameters ---------- ilist : tuple @@ -50,15 +57,6 @@ def normalize_indices_for_mlx(ilist, idx_list): ... (np.int64(0), np.int64(2)), (slice(None, None),) ... ) >>> # After normalization, slice components are Python int - - Notes - ----- - This conversion is necessary because MLX's C++ indexing implementation - does not recognize NumPy scalar types, raising ValueError when encountered. - Additionally, mlx_typify converts NumPy scalars to MLX arrays, which also - need to be converted back to Python int for use in indexing operations. - Converting to Python int is zero-cost for Python int inputs and minimal - overhead for NumPy scalars and MLX scalar arrays. """ import mlx.core as mx diff --git a/tests/link/mlx/test_mlx_indexing_behavior.py b/tests/link/mlx/test_mlx_indexing_behavior.py deleted file mode 100644 index 8021ae6126..0000000000 --- a/tests/link/mlx/test_mlx_indexing_behavior.py +++ /dev/null @@ -1,212 +0,0 @@ -""" -Pure MLX indexing behavior tests. - -This module tests MLX's indexing capabilities with different index types -to understand what conversions are needed for PyTensor compatibility. -""" - -import numpy as np -import pytest - - -mx = pytest.importorskip("mlx.core") - - -def test_mlx_python_int_indexing(): - """Test that MLX accepts Python int for indexing.""" - x = mx.array([1, 2, 3, 4, 5]) - - # Single index - result = x[2] - assert result == 3 - - # Slice with Python int - result = x[1:4] - assert list(result) == [2, 3, 4] - - # Slice with step :) - result = x[0:5:2] - assert list(result) == [1, 3, 5] - - -def test_mlx_numpy_int64_single_index(): - """Test MLX behavior with np.int64 single index.""" - x = mx.array([1, 2, 3, 4, 5]) - - # This should fail with MLX - with pytest.raises(ValueError, match="Cannot index mlx array"): - _ = x[np.int64(2)] - - -def test_mlx_numpy_int64_in_slice(): - """Test MLX behavior with np.int64 in slice components.""" - x = mx.array([1, 2, 3, 4, 5]) - - # Slice with np.int64 start - with pytest.raises(ValueError, match="Slice indices must be integers or None"): - _ = x[np.int64(1) : 4] - - # Slice with np.int64 stop - with pytest.raises(ValueError, match="Slice indices must be integers or None"): - _ = x[1 : np.int64(4)] - - # Slice with np.int64 step - with pytest.raises(ValueError, match="Slice indices must be integers or None"): - _ = x[0 : 5 : np.int64(2)] - - -def test_mlx_conversion_int64_to_python_int(): - """Test that converting np.int64 to Python int works for MLX indexing.""" - x = mx.array([1, 2, 3, 4, 5]) - - # Convert np.int64 to Python int - idx = int(np.int64(2)) - result = x[idx] - assert result == 3 - - # Convert in slice - start = int(np.int64(1)) - stop = int(np.int64(4)) - step = int(np.int64(2)) - result = x[start:stop:step] - assert list(result) == [2, 4] - - -def test_mlx_slice_with_none(): - """Test that MLX accepts None in slice components.""" - x = mx.array([1, 2, 3, 4, 5]) - - # Slice with None - result = x[None:3] - assert list(result) == [1, 2, 3] - - result = x[2:None] - assert list(result) == [3, 4, 5] - - result = x[None:None:2] - assert list(result) == [1, 3, 5] - - -def test_mlx_multidimensional_indexing(): - """Test MLX indexing with multidimensional arrays.""" - x = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - - # Python int indexing works - result = x[1, 2] - assert result == 6 - - # Mixed slice and int - result = x[1, :] - assert list(result) == [4, 5, 6] - - result = x[:, 1] - assert list(result) == [2, 5, 8] - - # Multiple slices - result = x[0:2, 1:3] - expected = [[2, 3], [5, 6]] - assert result.tolist() == expected - - -def test_mlx_negative_indices(): - """Test MLX with negative indices (both Python int and np.int64).""" - x = mx.array([1, 2, 3, 4, 5]) - - # Negative Python int works - result = x[-1] - assert result == 5 - - result = x[-3:-1] - assert list(result) == [3, 4] - - # Negative np.int64 should fail - with pytest.raises(ValueError, match="Cannot index mlx array"): - _ = x[np.int64(-1)] - - # But converting to Python int works - idx = int(np.int64(-1)) - result = x[idx] - assert result == 5 - - -def test_mlx_array_indexing(): - """Test MLX with array indices (advanced indexing).""" - x = mx.array([1, 2, 3, 4, 5]) - - # Array indexing with MLX array works - indices = mx.array([0, 2, 4]) - result = x[indices] - assert list(result) == [1, 3, 5] - - # Array indexing with NumPy array should fail - indices = np.array([0, 2, 4]) - with pytest.raises(ValueError, match="Cannot index mlx array"): - _ = x[indices] - - # But converting NumPy array to MLX array works - indices_mlx = mx.array(indices) - result = x[indices_mlx] - assert list(result) == [1, 3, 5] - - -def test_conversion_helper_behavior(): - """Test the behavior of our proposed int conversion helper.""" - - def get_slice_int(element): - """Helper to convert slice components to Python int.""" - if element is None: - return None - try: - return int(element) - except Exception: - return element - - # Test with None - assert get_slice_int(None) is None - - # Test with Python int - assert get_slice_int(5) == 5 - assert isinstance(get_slice_int(5), int) - - # Test with np.int64 - assert get_slice_int(np.int64(5)) == 5 - assert isinstance(get_slice_int(np.int64(5)), int) - - # Test with np.int32 - assert get_slice_int(np.int32(5)) == 5 - assert isinstance(get_slice_int(np.int32(5)), int) - - # Test with array (should pass through) - arr = np.array([1, 2, 3]) - result = get_slice_int(arr) - assert np.array_equal(result, arr) - - -def test_mlx_indexing_with_converted_slices(): - """Test that MLX indexing works after converting slice components.""" - x = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - - def normalize_slice(s): - """Convert slice components to Python int.""" - if not isinstance(s, slice): - # For non-slice indices, try to convert to int - try: - return int(s) - except (TypeError, ValueError): - return s - - return slice( - int(s.start) if s.start is not None else None, - int(s.stop) if s.stop is not None else None, - int(s.step) if s.step is not None else None, - ) - - # Create slices with np.int64 - slice1 = slice(np.int64(0), np.int64(2), None) - slice2 = slice(None, np.int64(2), None) - - # Convert and use - normalized = (normalize_slice(slice1), normalize_slice(slice2)) - result = x[normalized] - expected = [[1, 2], [4, 5]] - assert result.tolist() == expected From 9632ad61f2dc05a8ae6b941adda09a0f77382fa7 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Tue, 28 Oct 2025 23:01:03 +0200 Subject: [PATCH 6/9] Simplify based on Ricardo input --- pytensor/link/mlx/dispatch/subtensor.py | 116 +++--------------------- 1 file changed, 14 insertions(+), 102 deletions(-) diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index fca5bb8985..074f9d8ad8 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -1,5 +1,6 @@ from copy import deepcopy +import mlx.core as mx import numpy as np from pytensor.link.mlx.dispatch.basic import mlx_funcify @@ -16,111 +17,34 @@ def normalize_indices_for_mlx(ilist, idx_list): - """Convert indices to MLX-compatible format. - - MLX has strict requirements for indexing: - - Integer indices must be Python int, not np.int64 or other NumPy integer types - - Slice components (start, stop, step) must be Python int or None, not np.int64 - - MLX arrays created from scalars need to be converted back to Python int - - Array indices for advanced indexing are handled separately - - This function converts all integer-like indices and slice components to Python int - while preserving None values and passing through array indices unchanged. - - This conversion is necessary because MLX's C++ indexing implementation - does not recognize NumPy scalar types, raising ValueError when encountered. - Additionally, mlx_typify converts NumPy scalars to MLX arrays, which also - need to be converted back to Python int for use in indexing operations. - Converting to Python int is zero-cost for Python int inputs and minimal - overhead for NumPy scalars and MLX scalar arrays. - - Parameters - ---------- - ilist : tuple - Runtime index values to be passed to indices_from_subtensor - idx_list : tuple - Static index specification from the Op's idx_list attribute - - Returns - ------- - tuple - Normalized indices compatible with MLX array indexing - - Examples - -------- - >>> # Single np.int64 index converted to Python int - >>> normalize_indices_for_mlx((np.int64(1),), (True,)) - (1,) - - >>> # Slice with np.int64 components - >>> indices = indices_from_subtensor( - ... (np.int64(0), np.int64(2)), (slice(None, None),) - ... ) - >>> # After normalization, slice components are Python int + """Convert numpy integers to Python integers for MLX indexing. + + MLX requires index values to be Python int, not np.int64 or other NumPy types. """ - import mlx.core as mx def normalize_element(element): - """Convert a single index element to MLX-compatible format.""" if element is None: - # None is valid in slices (e.g., x[None:5] or x[:None]) return None elif isinstance(element, slice): - # Recursively normalize slice components return slice( normalize_element(element.start), normalize_element(element.stop), normalize_element(element.step), ) - elif isinstance(element, mx.array): - # MLX arrays from mlx_typify need special handling - # If it's a 0-d array (scalar), convert to Python int/float - if element.ndim == 0: - # Extract the scalar value - item = element.item() - # Convert to Python int if it's an integer type - if element.dtype in ( - mx.int8, - mx.int16, - mx.int32, - mx.int64, - mx.uint8, - mx.uint16, - mx.uint32, - mx.uint64, - ): - return int(item) - else: - return float(item) - else: - # Multi-dimensional array for advanced indexing - pass through - return element - elif isinstance(element, (np.integer, np.floating)): - # Convert NumPy scalar to Python int/float - # This handles np.int64, np.int32, np.float64, etc. - return int(element) if isinstance(element, np.integer) else float(element) - elif isinstance(element, (int, float)): - # Python int/float are already compatible - return element + elif isinstance(element, mx.array) and element.ndim == 0: + return int(element.item()) + elif isinstance(element, np.integer): + return int(element) else: - # Pass through other types (arrays for advanced indexing, etc.) return element - # Get indices from PyTensor's subtensor utility - raw_indices = indices_from_subtensor(ilist, idx_list) - - # Normalize each index element - normalized = tuple(normalize_element(idx) for idx in raw_indices) - - return normalized + indices = indices_from_subtensor(ilist, idx_list) + return tuple(normalize_element(idx) for idx in indices) @mlx_funcify.register(Subtensor) def mlx_funcify_Subtensor(op, node, **kwargs): - """MLX implementation of Subtensor operation. - - Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. - """ + """MLX implementation of Subtensor.""" idx_list = getattr(op, "idx_list", None) def subtensor(x, *ilists): @@ -137,11 +61,7 @@ def subtensor(x, *ilists): @mlx_funcify.register(AdvancedSubtensor) @mlx_funcify.register(AdvancedSubtensor1) def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): - """MLX implementation of AdvancedSubtensor operation. - - Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX, - including handling np.int64 in mixed basic/advanced indexing scenarios. - """ + """MLX implementation of AdvancedSubtensor.""" idx_list = getattr(op, "idx_list", None) def advanced_subtensor(x, *ilists): @@ -158,11 +78,7 @@ def advanced_subtensor(x, *ilists): @mlx_funcify.register(IncSubtensor) @mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_IncSubtensor(op, node, **kwargs): - """MLX implementation of IncSubtensor operation. - - Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. - Handles both set_instead_of_inc=True (assignment) and False (increment). - """ + """MLX implementation of IncSubtensor.""" idx_list = getattr(op, "idx_list", None) if getattr(op, "set_instead_of_inc", False): @@ -195,11 +111,7 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): @mlx_funcify.register(AdvancedIncSubtensor) def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): - """MLX implementation of AdvancedIncSubtensor operation. - - Uses normalize_indices_for_mlx to ensure all indices are compatible with MLX. - Note: For advanced indexing, ilist contains the actual array indices. - """ + """MLX implementation of AdvancedIncSubtensor.""" idx_list = getattr(op, "idx_list", None) if getattr(op, "set_instead_of_inc", False): From c457a13cbbaaf3f41894e50755a098feb4ebb0c4 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Tue, 28 Oct 2025 23:06:30 +0200 Subject: [PATCH 7/9] Update subtensor.py --- pytensor/link/mlx/dispatch/subtensor.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index 074f9d8ad8..12c3e91617 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -32,9 +32,23 @@ def normalize_element(element): normalize_element(element.step), ) elif isinstance(element, mx.array) and element.ndim == 0: - return int(element.item()) + try: + return int(element.item()) + except (TypeError, ValueError) as e: + raise TypeError( + "MLX backend does not support symbolic indices. " + "Index values must be concrete (constant) integers, not symbolic variables. " + f"Got: {element}" + ) from e elif isinstance(element, np.integer): - return int(element) + try: + return int(element) + except (TypeError, ValueError) as e: + raise TypeError( + "MLX backend does not support symbolic indices. " + "Index values must be concrete (constant) integers, not symbolic variables. " + f"Got: {element}" + ) from e else: return element From cd7a2d03ae9f7649e7e27fb78f9ab91c1b7d5017 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Wed, 29 Oct 2025 00:02:10 +0200 Subject: [PATCH 8/9] Simplify error msg --- pytensor/link/mlx/dispatch/subtensor.py | 29 +++++++++++-------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index 12c3e91617..2bf848657a 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -22,6 +22,17 @@ def normalize_indices_for_mlx(ilist, idx_list): MLX requires index values to be Python int, not np.int64 or other NumPy types. """ + def to_int(value, element): + """Convert value to Python int with helpful error message.""" + try: + return int(value) + except (TypeError, ValueError) as e: + raise TypeError( + "MLX backend does not support symbolic indices. " + "Index values must be concrete (constant) integers, not symbolic variables. " + f"Got: {element}" + ) from e + def normalize_element(element): if element is None: return None @@ -32,23 +43,9 @@ def normalize_element(element): normalize_element(element.step), ) elif isinstance(element, mx.array) and element.ndim == 0: - try: - return int(element.item()) - except (TypeError, ValueError) as e: - raise TypeError( - "MLX backend does not support symbolic indices. " - "Index values must be concrete (constant) integers, not symbolic variables. " - f"Got: {element}" - ) from e + return to_int(element.item(), element) elif isinstance(element, np.integer): - try: - return int(element) - except (TypeError, ValueError) as e: - raise TypeError( - "MLX backend does not support symbolic indices. " - "Index values must be concrete (constant) integers, not symbolic variables. " - f"Got: {element}" - ) from e + return to_int(element, element) else: return element From 449c2df59e366d114e315a50dab7e0e91377b23c Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Thu, 30 Oct 2025 19:43:21 +0200 Subject: [PATCH 9/9] Refactor MLX subtensor dispatch and update test Simplifies index normalization logic in MLX subtensor dispatch functions by separating basic and advanced indexing cases. Updates the advanced incsubtensor test to use vector array indices and a matching value shape for improved coverage. --- pytensor/link/mlx/dispatch/subtensor.py | 31 ++++++++++++++----------- tests/link/mlx/test_subtensor.py | 11 +++++---- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index 2bf848657a..5c909eb225 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -16,7 +16,7 @@ from pytensor.tensor.type_other import MakeSlice -def normalize_indices_for_mlx(ilist, idx_list): +def normalize_indices_for_mlx(indices): """Convert numpy integers to Python integers for MLX indexing. MLX requires index values to be Python int, not np.int64 or other NumPy types. @@ -49,18 +49,19 @@ def normalize_element(element): else: return element - indices = indices_from_subtensor(ilist, idx_list) return tuple(normalize_element(idx) for idx in indices) @mlx_funcify.register(Subtensor) def mlx_funcify_Subtensor(op, node, **kwargs): """MLX implementation of Subtensor.""" - idx_list = getattr(op, "idx_list", None) + idx_list = op.idx_list def subtensor(x, *ilists): + # Convert ilist to indices using idx_list (basic subtensor) + indices = indices_from_subtensor(ilists, idx_list) # Normalize indices to handle np.int64 and other NumPy types - indices = normalize_indices_for_mlx(ilists, idx_list) + indices = normalize_indices_for_mlx(indices) if len(indices) == 1: indices = indices[0] @@ -73,11 +74,11 @@ def subtensor(x, *ilists): @mlx_funcify.register(AdvancedSubtensor1) def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): """MLX implementation of AdvancedSubtensor.""" - idx_list = getattr(op, "idx_list", None) def advanced_subtensor(x, *ilists): # Normalize indices to handle np.int64 and other NumPy types - indices = normalize_indices_for_mlx(ilists, idx_list) + # Advanced indexing doesn't use idx_list or indices_from_subtensor + indices = normalize_indices_for_mlx(ilists) if len(indices) == 1: indices = indices[0] @@ -87,12 +88,11 @@ def advanced_subtensor(x, *ilists): @mlx_funcify.register(IncSubtensor) -@mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_IncSubtensor(op, node, **kwargs): """MLX implementation of IncSubtensor.""" - idx_list = getattr(op, "idx_list", None) + idx_list = op.idx_list - if getattr(op, "set_instead_of_inc", False): + if op.set_instead_of_inc: def mlx_fn(x, indices, y): if not op.inplace: @@ -109,8 +109,10 @@ def mlx_fn(x, indices, y): return x def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): + # Convert ilist to indices using idx_list (basic inc_subtensor) + indices = indices_from_subtensor(ilist, idx_list) # Normalize indices to handle np.int64 and other NumPy types - indices = normalize_indices_for_mlx(ilist, idx_list) + indices = normalize_indices_for_mlx(indices) if len(indices) == 1: indices = indices[0] @@ -121,11 +123,11 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): @mlx_funcify.register(AdvancedIncSubtensor) +@mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): """MLX implementation of AdvancedIncSubtensor.""" - idx_list = getattr(op, "idx_list", None) - if getattr(op, "set_instead_of_inc", False): + if op.set_instead_of_inc: def mlx_fn(x, indices, y): if not op.inplace: @@ -141,9 +143,10 @@ def mlx_fn(x, indices, y): x[indices] += y return x - def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): + def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn): # Normalize indices to handle np.int64 and other NumPy types - indices = normalize_indices_for_mlx(ilist, idx_list) + # Advanced indexing doesn't use idx_list or indices_from_subtensor + indices = normalize_indices_for_mlx(ilist) # For advanced indexing, if we have a single tuple of indices, unwrap it if len(indices) == 1: diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index 3fa233fd57..cc4c108956 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -395,13 +395,16 @@ def test_mlx_advanced_incsubtensor_with_numpy_int64(): x_np = np.arange(15, dtype=np.float32).reshape((5, 3)) x_pt = pt.constant(x_np) - # Value to set/increment + # Value to set/increment - using 4 rows now for vector indexing y_pt = pt.as_tensor_variable( - np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) + np.array( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], + dtype=np.float32, + ) ) - # Advanced indexing set with array indices - indices = [np.int64(0), np.int64(2)] + # Advanced indexing set with vector array indices + indices = np.array([0, 1, 2, 3], dtype=np.int64) out_pt = pt_subtensor.set_subtensor(x_pt[indices], y_pt) compare_mlx_and_py([], [out_pt], [])