Skip to content

Commit 2ba060e

Browse files
committed
feat: Introduce a new benchmarking harness and add several new benchmarks for Xtructure components.
1 parent 1a17fd1 commit 2ba060e

12 files changed

Lines changed: 1200 additions & 124 deletions
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import argparse
2+
3+
import jax
4+
import jax.numpy as jnp
5+
6+
from xtructure import BGPQ, FieldDescriptor, xtructure_dataclass
7+
from xtructure_benchmarks.harness import (
8+
BenchmarkResult,
9+
add_harness_args,
10+
configure_precision,
11+
finalize_result,
12+
run_case,
13+
)
14+
15+
16+
@xtructure_dataclass
17+
class SmallValue:
18+
x: FieldDescriptor.scalar(dtype=jnp.uint32)
19+
20+
21+
@xtructure_dataclass
22+
class BigValue:
23+
x: FieldDescriptor.tensor(dtype=jnp.float32, shape=(64,))
24+
25+
26+
def _pop_process_mask(keys: jax.Array, pop_ratio: float, min_pop: int) -> jax.Array:
27+
filled = jnp.isfinite(keys)
28+
mult = jnp.maximum(1.0 + pop_ratio, 1.01)
29+
threshold = keys[0] * mult + 1e-6
30+
base = jnp.logical_and(filled, keys <= threshold)
31+
min_mask = jnp.logical_and(jnp.cumsum(filled) <= min_pop, filled)
32+
return jnp.logical_or(base, min_mask)
33+
34+
35+
def _parse_batch_sizes(batch_sizes: str) -> list[int]:
36+
if not batch_sizes:
37+
return []
38+
return [int(x.strip()) for x in batch_sizes.split(",") if x.strip()]
39+
40+
41+
def main() -> None:
42+
parser = argparse.ArgumentParser(description="BGPQ steady-state workload benchmark")
43+
add_harness_args(parser)
44+
parser.add_argument("--max-nodes", type=int, default=2**18)
45+
parser.add_argument("--batch-size", type=int, default=1024)
46+
parser.add_argument(
47+
"--batch-sizes",
48+
type=str,
49+
default="",
50+
help="Comma-separated batch sizes (e.g. 1024,4096,16384). Overrides --batch-size.",
51+
)
52+
parser.add_argument("--prefill", type=int, default=32)
53+
parser.add_argument("--branching-factor", type=int, default=2)
54+
parser.add_argument("--min-pop", type=int, default=1)
55+
parser.add_argument("--pop-ratio", type=float, default=0.5)
56+
parser.add_argument(
57+
"--pop-calls",
58+
type=int,
59+
default=1,
60+
help="Number of delete_mins() calls per step (default: 1).",
61+
)
62+
parser.add_argument(
63+
"--value-kind", choices=["u32_small", "big_payload"], default="u32_small"
64+
)
65+
parser.add_argument(
66+
"--output",
67+
type=str,
68+
default="xtructure_benchmarks/results/bgpq_workload_results.json",
69+
)
70+
args = parser.parse_args()
71+
configure_precision(args)
72+
73+
batch_sizes = _parse_batch_sizes(args.batch_sizes) or [args.batch_size]
74+
value_cls = SmallValue if args.value_kind == "u32_small" else BigValue
75+
result = BenchmarkResult()
76+
77+
base_key = jax.random.PRNGKey(args.seed)
78+
79+
for batch_size in batch_sizes:
80+
heap = BGPQ.build(args.max_nodes, batch_size, value_cls, jnp.float32)
81+
82+
# Derive per-batch-size RNG so multi-size runs stay reproducible.
83+
k0 = jax.random.fold_in(base_key, int(batch_size))
84+
k_prefill_keys, k_prefill_vals, run_key = jax.random.split(k0, 3)
85+
pre_keys = jax.random.uniform(
86+
k_prefill_keys, (args.prefill, batch_size), dtype=jnp.float32
87+
)
88+
pre_vals = value_cls.random(
89+
shape=(args.prefill, batch_size), key=k_prefill_vals
90+
)
91+
92+
@jax.jit
93+
def prefill(h, keys, vals):
94+
def body(i, carry):
95+
return carry.insert(
96+
keys[i], jax.tree_util.tree_map(lambda x: x[i], vals)
97+
)
98+
99+
return jax.lax.fori_loop(0, keys.shape[0], body, h)
100+
101+
heap = prefill(heap, pre_keys, pre_vals)
102+
jax.block_until_ready(heap)
103+
104+
@jax.jit
105+
def one_step(state):
106+
h, k = state
107+
k, sk, sv = jax.random.split(k, 3)
108+
child_keys = jax.random.uniform(
109+
sk,
110+
(args.pop_calls, args.branching_factor, batch_size),
111+
dtype=jnp.float32,
112+
)
113+
child_vals = value_cls.random(
114+
shape=(args.pop_calls, args.branching_factor, batch_size), key=sv
115+
)
116+
117+
processed = jnp.asarray(0, dtype=jnp.int32)
118+
119+
def pop_body(i, carry):
120+
hp, proc = carry
121+
hp, popped_keys, popped_vals = hp.delete_mins()
122+
process_mask = _pop_process_mask(
123+
popped_keys, args.pop_ratio, args.min_pop
124+
)
125+
requeue_keys = jnp.where(process_mask, jnp.inf, popped_keys)
126+
hp = hp.insert(requeue_keys, popped_vals)
127+
128+
row_child_keys = child_keys[i]
129+
row_child_vals = jax.tree_util.tree_map(lambda x: x[i], child_vals)
130+
131+
def insert_child_row(j, hcarry):
132+
masked_keys = jnp.where(process_mask, row_child_keys[j], jnp.inf)
133+
row_vals = jax.tree_util.tree_map(lambda x: x[j], row_child_vals)
134+
return hcarry.insert(masked_keys, row_vals)
135+
136+
hp = jax.lax.fori_loop(0, args.branching_factor, insert_child_row, hp)
137+
return hp, proc + jnp.sum(process_mask, dtype=jnp.int32)
138+
139+
h, processed = jax.lax.fori_loop(
140+
0, args.pop_calls, pop_body, (h, processed)
141+
)
142+
return (h, k), processed
143+
144+
@jax.jit
145+
def run_inner(state):
146+
def body(_, carry):
147+
core, total_processed = carry
148+
core, processed = one_step(core)
149+
return core, total_processed + processed
150+
151+
return jax.lax.fori_loop(
152+
0, args.inner_steps, body, (state, jnp.asarray(0, jnp.int32))
153+
)
154+
155+
def fn():
156+
(h, _), processed = run_inner((heap, run_key))
157+
if args.transfer_policy == "payload_only":
158+
return processed
159+
return (h.heap_size, h.buffer_size, processed)
160+
161+
candidates_per_call = (
162+
batch_size * args.branching_factor * args.pop_calls * args.inner_steps
163+
)
164+
run_case(
165+
result,
166+
name="bgpq_frontier_step",
167+
params={
168+
"max_nodes": args.max_nodes,
169+
"batch_size": int(batch_size),
170+
"prefill": args.prefill,
171+
"branching_factor": args.branching_factor,
172+
"min_pop": args.min_pop,
173+
"pop_ratio": args.pop_ratio,
174+
"pop_calls": args.pop_calls,
175+
"value_kind": args.value_kind,
176+
},
177+
payload_items=int(candidates_per_call),
178+
fn=fn,
179+
args=args,
180+
)
181+
182+
finalize_result(result, args, args.output, extra_run={"batch_sizes": batch_sizes})
183+
184+
185+
if __name__ == "__main__":
186+
main()

0 commit comments

Comments
 (0)