|
17 | 17 | from evox.operators import selection, mutation, crossover, non_dominated_sort
|
18 | 18 |
|
19 | 19 |
|
20 |
| -@partial(jax.jit, static_argnums=[0, 1]) |
21 |
| -def calculate_alpha(N, k): |
22 |
| - alpha = jnp.zeros(N) |
23 |
| - |
24 |
| - for i in range(1, k + 1): |
25 |
| - num = jnp.prod((k - jnp.arange(1, i)) / (N - jnp.arange(1, i))) |
26 |
| - alpha = alpha.at[i - 1].set(num / i) |
27 |
| - return alpha |
28 |
| - |
29 |
| - |
30 | 20 | @partial(jax.jit, static_argnums=[2, 3])
|
31 | 21 | def cal_hv(points, ref, k, n_sample, key):
|
32 | 22 | n, m = jnp.shape(points)
|
33 |
| - alpha = calculate_alpha(n, k) |
| 23 | + |
| 24 | + # hit in alpha relevant partition |
| 25 | + alpha = jnp.cumprod( |
| 26 | + jnp.r_[1, (k - jnp.arange(1, n)) / (n - jnp.arange(1, n))] |
| 27 | + ) / jnp.arange(1, n + 1) |
34 | 28 |
|
35 | 29 | f_min = jnp.min(points, axis=0)
|
36 | 30 |
|
37 |
| - s = jax.random.uniform(key, shape=(n_sample, m), minval=f_min, maxval=ref) |
| 31 | + samples = jax.random.uniform(key, shape=(n_sample, m), minval=f_min, maxval=ref) |
38 | 32 |
|
39 |
| - pds = jnp.zeros((n, n_sample), dtype=bool) |
| 33 | + # update hypervolume estimates |
40 | 34 | ds = jnp.zeros((n_sample,))
|
41 |
| - |
42 |
| - def body_fun1(i, vals): |
43 |
| - pds, ds = vals |
44 |
| - x = jnp.sum((jnp.tile(points[i, :], (n_sample, 1)) - s) <= 0, axis=1) == m |
45 |
| - pds = pds.at[i].set(jnp.where(x, True, pds[i])) |
46 |
| - ds = jnp.where(x, ds + 1, ds) |
47 |
| - return pds, ds |
48 |
| - |
49 |
| - pds, ds = jax.lax.fori_loop(0, n, body_fun1, (pds, ds)) |
50 |
| - ds = ds - 1 |
51 |
| - |
52 |
| - f = jnp.zeros((n,)) |
53 |
| - |
54 |
| - def body_fun2(pd): |
55 |
| - temp = jnp.where(pd, ds, -1).astype(int) |
| 35 | + pds = jax.vmap( |
| 36 | + lambda x: jnp.sum((jnp.tile(x, (n_sample, 1)) - samples) <= 0, axis=1) == m, |
| 37 | + in_axes=0, |
| 38 | + out_axes=0, |
| 39 | + )(points) |
| 40 | + ds = jnp.sum(jnp.where(pds, ds + 1, ds), axis=0) |
| 41 | + ds = jnp.where(ds == 0, ds, ds - 1) |
| 42 | + |
| 43 | + def cal_f(val): |
| 44 | + temp = jnp.where(val, ds, -1).astype(int) |
56 | 45 | value = jnp.where(temp != -1, alpha[temp], 0)
|
57 | 46 | value = jnp.sum(value)
|
58 | 47 | return value
|
59 |
| - |
60 |
| - f = jax.vmap(body_fun2)(pds) |
| 48 | + |
| 49 | + f = jax.vmap(cal_f, in_axes=0, out_axes=0)(pds) |
61 | 50 | f = f * jnp.prod(ref - f_min) / n_sample
|
62 | 51 |
|
63 | 52 | return f
|
|
0 commit comments