---
title: "Pytrees for Scientific Python"
date: 2025-07-08
draft: false
description: "
Introducing PyTrees for Scientific Python. We discuss what PyTrees are, how they're useful in the realm of scientific Python, and how to work _efficiently_ with them.
"
tags: ["PyTrees", "Functional Programming", "Tree-like data manipulation"]
displayInList: true
author: ["Peter Fackeldey", "Mihai Maruseac", "Matthew Feickert"]
summary: |
  This blog introduces PyTrees — nested Python data structures (such as lists, dicts, and tuples) with numerical leaf values — designed to simplify working with complex, hierarchically organized data.
  While such structures are often cumbersome to manipulate, PyTrees make them more manageable by allowing them to be flattened into a list of leaves along with a reusable structure blueprint in a _generic_ way.
  This enables flexible, generic operations like mapping and reducing from functional programming.
  By bringing those functional paradigms to structured data, PyTrees let you focus on what transformations to apply, not how to traverse the structure — no matter how deeply nested or complex it is.
---

## Manipulating Tree-like Data using Functional Programming Paradigms

A "PyTree" is a nested collection of Python containers (e.g. dicts, (named) tuples, lists, ...), where the leaves are of interest.
In the scientific world, such a PyTree could consist of experimental measurements of different properties at different timestamps and measurement settings, resulting in a highly complex, nested, and not necessarily rectangular data structure.
Such collections can be cumbersome to manipulate _efficiently_, especially if they are nested to arbitrary depth.
Manipulating them often requires complex recursive logic that usually does not generalize to other nested Python containers (PyTrees), e.g. for new measurements.

The core concept of PyTrees is being able to flatten them into a flat collection of leaves and a "blueprint" of the tree structure, and then being able to unflatten them back into the original PyTree.
This allows for the application of generic transformations.
In this blog post, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree) — a standalone PyTree library — that enables these transformations. It focuses on performance, is feature rich, has minimal dependencies, and has been adopted by [PyTorch](https://pytorch.org), [Keras](https://keras.io), and [TensorFlow](https://github.com/tensorflow/tensorflow) (through Keras) as a core dependency.
For example, on a PyTree with NumPy arrays as leaves, taking the square root of each leaf with `optree.tree_map(np.sqrt, tree)`:

```python
import optree as pt
import numpy as np

# tuple of a list containing a dict (with an array as value) and an array
tree = ([[{"foo": np.array([4.0])}], np.array([9.0])],)

# sqrt of each leaf array
sqrt_tree = pt.tree_map(np.sqrt, tree)
print(f"{sqrt_tree=}")
# >> sqrt_tree=([[{'foo': array([2.])}], array([3.])],)

# reductions
all_positive = all(np.all(x > 0.0) for x in pt.tree_iter(tree))
print(f"{all_positive=}")
# >> all_positive=True

summed = np.sum(pt.tree_reduce(np.add, tree))
print(f"{summed=}")
# >> summed=np.float64(13.0)
```

The trick here is that these operations can be implemented in three steps, e.g. `tree_map`:

```python
# step 1: flatten into the leaves and a "blueprint" of the structure
leaves, treedef = pt.tree_flatten(tree)

# step 2: apply the function (`fun`, e.g. np.sqrt) to each leaf
new_leaves = tuple(map(fun, leaves))

# step 3: unflatten back into the original structure
result_tree = pt.tree_unflatten(treedef, new_leaves)
```
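
To make the mechanism concrete, here is a deliberately simplified, pure-Python sketch of flatten/unflatten for lists, tuples, and dicts. This is _not_ optree's implementation — the blueprint encoding and leaf-ordering rules here are purely illustrative:

```python
def flatten(node):
    """Recursively split a nested structure into (leaves, blueprint)."""
    if isinstance(node, dict):
        items = [(k, flatten(v)) for k, v in node.items()]
        leaves = [leaf for _, (sub, _) in items for leaf in sub]
        return leaves, ("dict", [(k, spec) for k, (_, spec) in items])
    if isinstance(node, (list, tuple)):
        flat = [flatten(v) for v in node]
        leaves = [leaf for sub, _ in flat for leaf in sub]
        return leaves, (type(node).__name__, [spec for _, spec in flat])
    return [node], "leaf"  # anything that is not a container is a leaf


def unflatten(spec, leaves):
    """Rebuild the structure from a blueprint and an iterator of leaves."""
    if spec == "leaf":
        return next(leaves)
    kind, children = spec
    if kind == "dict":
        return {k: unflatten(sub, leaves) for k, sub in children}
    values = [unflatten(sub, leaves) for sub in children]
    return tuple(values) if kind == "tuple" else values


tree = ([[{"foo": 4.0}], 9.0],)
leaves, treedef = flatten(tree)
print(leaves)  # >> [4.0, 9.0]

# the blueprint is enough to reconstruct the original PyTree
assert unflatten(treedef, iter(leaves)) == tree
```

Real PyTree libraries additionally handle namedtuples, dict key ordering, custom node types, and do this at C-extension speed — but the flatten/transform/unflatten cycle is exactly this.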

### PyTree Origins

Originally, the concept of PyTrees was developed by the [JAX](https://docs.jax.dev/en/latest/) project to make nested collections of JAX arrays work transparently at the "JIT-boundary" (the JAX JIT toolchain does not know about Python containers, only about JAX arrays).
However, PyTrees were quickly adopted by AI researchers for broader use-cases: semantically grouping layers of weights and biases in a list of named tuples (or dictionaries) is a common pattern in the JAX-AI-world, as shown in the following (pseudo) Python snippet:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Layer(NamedTuple):
    W: jax.Array
    b: jax.Array


layers = [
    Layer(W=jnp.array(...), b=jnp.array(...)),  # first layer
    Layer(W=jnp.array(...), b=jnp.array(...)),  # second layer
    ...,
]


@jax.jit
def neural_network(layers: list[Layer], x: jax.Array) -> jax.Array:
    for layer in layers:
        x = jnp.tanh(layer.W @ x + layer.b)
    return x


prediction = neural_network(layers=layers, x=jnp.array(...))
```

Here, `layers` is a PyTree — a `list` of multiple `Layer` instances — and the JIT-compiled `neural_network` function _just works_ with this data structure as input.
Although you cannot see what happens inside of `jax.jit`, `layers` is automatically flattened by the `jax.jit` decorator into a flat iterable of arrays, which the JAX JIT toolchain understands, in contrast to a Python `list` of `NamedTuple`s.

### PyTrees in Scientific Python

Wouldn't it be nice to make workflows in the scientific Python ecosystem _just work_ with any PyTree?

Giving semantic meaning to numeric data through PyTrees can be useful for applications outside of AI as well.
Consider the following minimization of the [Rosenbrock](https://en.wikipedia.org/wiki/Rosenbrock_function) function:

```python
from scipy.optimize import minimize


def rosenbrock(params: tuple[float, float]) -> float:
    """
    Rosenbrock function. Minimum: f(1, 1) = 0.

    https://en.wikipedia.org/wiki/Rosenbrock_function
    """
    x, y = params
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


x0 = (0.9, 1.2)
res = minimize(rosenbrock, x0)
print(res.x)
# >> [0.99999569 0.99999137]
```

Now, let's consider a minimization that uses a more complex type for the parameters — a `NamedTuple` that describes our fit parameters:

```python
from typing import Callable, NamedTuple

import optree as pt
from scipy.optimize import minimize as sp_minimize


class Params(NamedTuple):
    x: float
    y: float


def rosenbrock(params: Params) -> float:
    """
    Rosenbrock function. Minimum: f(1, 1) = 0.

    https://en.wikipedia.org/wiki/Rosenbrock_function
    """
    return (1 - params.x) ** 2 + 100 * (params.y - params.x**2) ** 2


def minimize(fun: Callable, params: Params) -> Params:
    # flatten and store the PyTree definition
    flat_params, treedef = pt.tree_flatten(params)

    # wrap fun to work with flat_params
    def wrapped_fun(flat_params):
        params = pt.tree_unflatten(treedef, flat_params)
        return fun(params)

    # actual minimization
    res = sp_minimize(wrapped_fun, flat_params)

    # re-wrap the bestfit values into Params with the stored PyTree definition
    return pt.tree_unflatten(treedef, res.x)


# scipy minimize that works with any PyTree
x0 = Params(x=0.9, y=1.2)
bestfit_params = minimize(rosenbrock, x0)
print(bestfit_params)
# >> Params(x=np.float64(0.999995688776513), y=np.float64(0.9999913673387226))
```

This new `minimize` function works with _any_ PyTree!

Let's now consider a modified and more complex version of the Rosenbrock function that relies on two sets of `Params` as input — a common pattern for hierarchical models (e.g. a superposition of various probability density functions):

```python
import numpy as np


def rosenbrock_modified(two_params: tuple[Params, Params]) -> float:
    """
    Modified Rosenbrock where the x and y parameters are determined by
    non-linear transformations of two versions of each, i.e.:
      x = arcsin(min(x1, x2) / max(x1, x2))
      y = sigmoid(y1 / y2)
    """
    p1, p2 = two_params

    # calculate `x` and `y` from two sources:
    x = np.arcsin(min(p1.x, p2.x) / max(p1.x, p2.x))
    y = 1 / (1 + np.exp(-(p1.y / p2.y)))

    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


x0 = (Params(x=0.9, y=1.2), Params(x=0.8, y=1.3))
bestfit_params = minimize(rosenbrock_modified, x0)
print(bestfit_params)
# >> (
# Params(x=np.float64(4.686181110201706), y=np.float64(0.05129869722505759)),
# Params(x=np.float64(3.9432263101976073), y=np.float64(0.005146110126174016)),
# )
```

The new `minimize` still works, because a `tuple` of `Params` is just _another_ PyTree!

### Final Thought

Working with nested data structures doesn’t have to be messy.
PyTrees let you focus on the data and the transformations you want to apply, in a generic manner.
Whether you're building neural networks, optimizing scientific models, or just dealing with complex nested Python containers, PyTrees can make your code cleaner, more flexible, and just nicer to work with.