Commit c3d37b8

Add blog post about Pytrees for Scientific Python (#250)

Authored by: pfackeldey, matthewfeickert, mihaimaruseac, stefanv, rossbar
Co-authored-by: Matthew Feickert, Mihai Maruseac, Stefan van der Walt, Ross Barnowski, Xuehai Pan
1 parent 84191db commit c3d37b8

File tree: 1 file changed, +212 −0 lines
---
title: "Pytrees for Scientific Python"
date: 2025-07-08
draft: false
description: "
  Introducing PyTrees for Scientific Python. We discuss what PyTrees are, how they're useful in the realm of scientific Python, and how to work _efficiently_ with them.
"
tags: ["PyTrees", "Functional Programming", "Tree-like data manipulation"]
displayInList: true
author: ["Peter Fackeldey", "Mihai Maruseac", "Matthew Feickert"]
summary: |
  This blog introduces PyTrees &mdash; nested Python data structures (such as lists, dicts, and tuples) with numerical leaf values &mdash; designed to simplify working with complex, hierarchically organized data.
  While such structures are often cumbersome to manipulate, PyTrees make them more manageable by allowing them to be flattened into a list of leaves along with a reusable structure blueprint in a _generic_ way.
  This enables flexible, generic operations like mapping and reducing from functional programming.
  By bringing those functional paradigms to structured data, PyTrees let you focus on what transformations to apply, not how to traverse the structure &mdash; no matter how deeply nested or complex it is.
---
18+
## Manipulating Tree-like Data using Functional Programming Paradigms

A "PyTree" is a nested collection of Python containers (e.g. dicts, (named) tuples, lists, ...), where the leaves are the values of interest.
In the scientific world, such a PyTree could consist of experimental measurements of different properties at different timestamps and measurement settings, resulting in a highly complex, nested, and not necessarily rectangular data structure.
Such collections can be cumbersome to manipulate _efficiently_, especially if they are nested to arbitrary depth.
Doing so often requires complex recursive logic that rarely generalizes to other nested Python containers (PyTrees), e.g. for new measurements.

The core concept of PyTrees is being able to flatten them into a flat collection of leaves plus a "blueprint" of the tree structure, and then to unflatten them back into the original PyTree.
This allows for the application of generic transformations.
In this blog post, we use [`optree`](https://github.com/metaopt/optree/tree/main/optree) &mdash; a standalone PyTree library &mdash; that enables these transformations. It focuses on performance, is feature rich, has minimal dependencies, and has been adopted by [PyTorch](https://pytorch.org), [Keras](https://keras.io), and [TensorFlow](https://github.com/tensorflow/tensorflow) (through Keras) as a core dependency.
For example, on a PyTree with NumPy arrays as leaves, we can take the square root of each leaf with `optree.tree_map(np.sqrt, tree)`:

```python
import optree as pt
import numpy as np

# tuple of a list of a dict with an array as value, and an array
tree = ([[{"foo": np.array([4.0])}], np.array([9.0])],)

# sqrt of each leaf array
sqrt_tree = pt.tree_map(np.sqrt, tree)
print(f"{sqrt_tree=}")
# >> sqrt_tree=([[{'foo': array([2.])}], array([3.])],)

# reductions
all_positive = all(np.all(x > 0.0) for x in pt.tree_iter(tree))
print(f"{all_positive=}")
# >> all_positive=True

# pairwise-add all leaves, then np.sum handles leaves with more than one element
summed = np.sum(pt.tree_reduce(np.add, tree))
print(f"{summed=}")
# >> summed=np.float64(13.0)
```

The trick here is that these operations can be implemented in three steps, e.g. for `tree_map`:

```python
# step 1: flatten the tree into its leaves and a blueprint (treedef)
leaves, treedef = pt.tree_flatten(tree)

# step 2: apply a function (`fun`, e.g. np.sqrt) to each leaf
new_leaves = tuple(map(fun, leaves))

# step 3: unflatten the new leaves with the stored blueprint
result_tree = pt.tree_unflatten(treedef, new_leaves)
```

### PyTree Origins

Originally, the concept of PyTrees was developed by the [JAX](https://docs.jax.dev/en/latest/) project to make nested collections of JAX arrays work transparently at the "JIT boundary" (the JAX JIT toolchain does not know about Python containers, only about JAX arrays).
However, PyTrees were quickly adopted by AI researchers for broader use cases: semantically grouping layers of weights and biases in a list of named tuples (or dictionaries) is a common pattern in the JAX AI world, as shown in the following (pseudo) Python snippet:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Layer(NamedTuple):
    W: jax.Array
    b: jax.Array


layers = [
    Layer(W=jnp.array(...), b=jnp.array(...)),  # first layer
    Layer(W=jnp.array(...), b=jnp.array(...)),  # second layer
    ...,
]


@jax.jit
def neural_network(layers: list[Layer], x: jax.Array) -> jax.Array:
    for layer in layers:
        x = jnp.tanh(layer.W @ x + layer.b)
    return x


prediction = neural_network(layers=layers, x=jnp.array(...))
```

Here, `layers` is a PyTree &mdash; a `list` of multiple `Layer` instances &mdash; and the JIT-compiled `neural_network` function _just works_ with this data structure as input.
Although you cannot see what happens inside of `jax.jit`, `layers` is automatically flattened by the `jax.jit` decorator into a flat iterable of arrays, which the JAX JIT toolchain understands, in contrast to a Python `list` of `NamedTuple`s.


### PyTrees in Scientific Python

Wouldn't it be nice to make workflows in the scientific Python ecosystem _just work_ with any PyTree?

Giving semantic meaning to numeric data through PyTrees can be useful for applications outside of AI as well.
Consider the following minimization of the [Rosenbrock](https://en.wikipedia.org/wiki/Rosenbrock_function) function:

```python
from scipy.optimize import minimize


def rosenbrock(params: tuple[float, float]) -> float:
    """
    Rosenbrock function. Minimum: f(1, 1) = 0.

    https://en.wikipedia.org/wiki/Rosenbrock_function
    """
    x, y = params
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


x0 = (0.9, 1.2)
res = minimize(rosenbrock, x0)
print(res.x)
# >> [0.99999569 0.99999137]
```

Now, let's consider a minimization that uses a more complex type for the parameters &mdash; a `NamedTuple` that describes our fit parameters:

```python
from typing import Callable, NamedTuple

import optree as pt
from scipy.optimize import minimize as sp_minimize


class Params(NamedTuple):
    x: float
    y: float


def rosenbrock(params: Params) -> float:
    """
    Rosenbrock function. Minimum: f(1, 1) = 0.

    https://en.wikipedia.org/wiki/Rosenbrock_function
    """
    return (1 - params.x) ** 2 + 100 * (params.y - params.x**2) ** 2


def minimize(fun: Callable, params: Params) -> Params:
    # flatten and store the PyTree definition
    flat_params, treedef = pt.tree_flatten(params)

    # wrap fun to work with flat_params
    def wrapped_fun(flat_params):
        params = pt.tree_unflatten(treedef, flat_params)
        return fun(params)

    # actual minimization
    res = sp_minimize(wrapped_fun, flat_params)

    # re-wrap the best-fit values into Params with the stored PyTree definition
    return pt.tree_unflatten(treedef, res.x)


# scipy minimize that works with any PyTree
x0 = Params(x=0.9, y=1.2)
bestfit_params = minimize(rosenbrock, x0)
print(bestfit_params)
# >> Params(x=np.float64(0.999995688776513), y=np.float64(0.9999913673387226))
```
This new `minimize` function works with _any_ PyTree!
Let's now consider a modified and more complex version of the Rosenbrock function that relies on two sets of `Params` as input &mdash; a common pattern for hierarchical models (e.g. a superposition of various probability density functions):
```python
import numpy as np


def rosenbrock_modified(two_params: tuple[Params, Params]) -> float:
    """
    Modified Rosenbrock where the x and y parameters are determined by
    non-linear transformations of two versions of each, i.e.:
        x = arcsin(min(x1, x2) / max(x1, x2))
        y = sigmoid(y1 / y2)
    """
    p1, p2 = two_params

    # calculate `x` and `y` from two sources:
    x = np.arcsin(min(p1.x, p2.x) / max(p1.x, p2.x))
    y = 1 / (1 + np.exp(-(p1.y / p2.y)))

    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


x0 = (Params(x=0.9, y=1.2), Params(x=0.8, y=1.3))
bestfit_params = minimize(rosenbrock_modified, x0)
print(bestfit_params)
# >> (
#      Params(x=np.float64(4.686181110201706), y=np.float64(0.05129869722505759)),
#      Params(x=np.float64(3.9432263101976073), y=np.float64(0.005146110126174016)),
#    )
```
The new `minimize` still works, because a `tuple` of `Params` is just _another_ PyTree!
### Final Thought

Working with nested data structures doesn’t have to be messy.
PyTrees let you focus on the data and the transformations you want to apply, in a generic manner.
Whether you're building neural networks, optimizing scientific models, or just dealing with complex nested Python containers, PyTrees can make your code cleaner, more flexible, and just nicer to work with.
