-
Notifications
You must be signed in to change notification settings - Fork 0
/
tree_utils.py
38 lines (34 loc) · 1.42 KB
/
tree_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from jax import numpy as jnp
import jax as jx
#Adapted from https://gist.github.com/willwhitney/dd89cac6a5b771ccff18b06b33372c75
def tree_stack(trees):
    """Stack the corresponding leaves of several identically-structured trees.

    For example, given two trees ((a, b), c) and ((a', b'), c'), returns
    ((stack(a, a'), stack(b, b')), stack(c, c')).

    Useful for turning a list of objects into something you can feed to a
    vmapped function.

    Args:
        trees: An iterable of pytrees that all share the same structure.

    Returns:
        A single pytree with the structure of the first input tree, where each
        leaf is ``jnp.stack`` applied to the matching leaves of every input.

    Raises:
        IndexError: If ``trees`` is empty (same exception type as before, so
            existing callers keep working).
        ValueError: If any tree's structure differs from the first tree's.
    """
    # Materialize once so generators are supported and emptiness can be checked.
    trees = list(trees)
    if not trees:
        raise IndexError("tree_stack() requires at least one tree")

    leaves_list = []
    reference_treedef = None
    for index, tree in enumerate(trees):
        leaves, treedef = jx.tree_util.tree_flatten(tree)
        if reference_treedef is None:
            reference_treedef = treedef
        elif treedef != reference_treedef:
            # Without this check zip() below would silently truncate to the
            # shortest leaf list and drop data from the other trees.
            raise ValueError(
                f"tree_stack() got mismatched tree structures: tree {index} "
                f"has structure {treedef}, expected {reference_treedef}"
            )
        leaves_list.append(leaves)

    # zip(*...) groups the i-th leaf of every tree together.
    result_leaves = [jnp.stack(group) for group in zip(*leaves_list)]
    return reference_treedef.unflatten(result_leaves)
def tree_unstack(tree):
    """Split a tree of stacked leaves into a list of trees.

    Inverse of tree_stack. For example, given a tree ((a, b), c), where a, b,
    and c all have first dimension k, will make k trees
    [((a[0], b[0]), c[0]), ..., ((a[k-1], b[k-1]), c[k-1])]

    Useful for turning the output of a vmapped function into normal objects.

    Args:
        tree: A pytree whose leaves all share the same leading dimension.

    Returns:
        A list of pytrees, one per index along the leading dimension, each with
        the same structure as ``tree``.

    Raises:
        IndexError: If ``tree`` has no leaves, so the number of trees cannot be
            determined (same exception type as before).
        ValueError: If a leaf's leading dimension differs from the first
            leaf's.
    """
    leaves, treedef = jx.tree_util.tree_flatten(tree)
    if not leaves:
        raise IndexError("tree_unstack() requires a tree with at least one leaf")

    # The first leaf defines how many trees are produced.
    n_trees = leaves[0].shape[0]
    for position, leaf in enumerate(leaves):
        # Reject mismatched leaves up front; otherwise a longer leaf would be
        # silently truncated and a shorter one indexed out of range.
        if tuple(leaf.shape[:1]) != (n_trees,):
            raise ValueError(
                f"tree_unstack() expected every leaf to have leading dimension "
                f"{n_trees}, but leaf {position} has shape {tuple(leaf.shape)}"
            )

    return [
        treedef.unflatten([leaf[i] for leaf in leaves])
        for i in range(n_trees)
    ]