Commit b9e6480

Update MOs and Add LMOCSO (#76)
* fix bug of DTLZ7
* add grid sampling
* update rvea, ibea and hype
* modify dtlz.rst
* add lmocso
* modify docs
* reformat the code
* modify docs
* add docstring
* add docs
* fix bug of cos that is close to 0
* readme: use 🚀 to represent fast
* dev: enable multi-objective optimization for gym
* fix: avoid importing unnecessary packages
* fix: typo in readme
* flake.lock: Update Flake lock file updates:
  • Updated input 'flake-compat': 'github:edolstra/flake-compat/35bb57c0c8d8b62bbfd284272c928ceb64ddbde9' (2023-01-17) → 'github:edolstra/flake-compat/0f9255e01c2351cc7d116c072cb317785dd33b33' (2023-10-04)
  • Updated input 'nixpkgs': 'github:NixOS/nixpkgs/d680ded26da5cf104dd2735a51e88d2d8f487b4d' (2023-08-19) → 'github:NixOS/nixpkgs/f99e5f03cc0aa231ab5950a15ed02afec45ed51a' (2023-10-09)
  • Updated input 'utils': 'github:numtide/flake-utils/919d646de7be200f3bf08cb76ae1f09402b6f9b4' (2023-07-11) → 'github:numtide/flake-utils/ff7b65b44d01cf9ba6a71320833626af21126384' (2023-09-12)
* doc: enable markdown (MyST) doc
* dev: StdSOMonitor: get_min_fitness -> get_best_fitness (see the sketch below the commit header)
* dev: standardize the way to specify the optimization direction
* fix: envpool with 64-bit return values, and add test cases
* dev: add support for brax >= 0.9.0
* dev: add envpool to test requirements

---------

Co-authored-by: Bill Huang <[email protected]>
1 parent c179feb commit b9e6480
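The StdSOMonitor rename listed in the commit message, shown as a minimal sketch. It assumes the monitor class is importable from evox.monitors and that a workflow has already reported fitness values to it; only the method rename itself is taken from this commit.

from evox.monitors import StdSOMonitor  # assumption: StdSOMonitor lives in evox.monitors


def report_best(monitor: StdSOMonitor):
    # Old accessor used before this commit:
    #   best = monitor.get_min_fitness()
    # New, direction-aware accessor introduced by this commit:
    best = monitor.get_best_fitness()
    print(best)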

25 files changed: +424 -176 lines changed

+6
@@ -0,0 +1,6 @@
+=======
+LMOCSO
+=======
+
+.. autoclass:: evox.algorithms.LMOCSO
+   :members:

docs/source/api/problems/index.rst

+1-2
@@ -8,5 +8,4 @@ Problems
    :maxdepth: 2
 
    numerical/index
-   neuroevolution/index
-   rl/index
+   neuroevolution/index

docs/source/api/problems/neuroevolution/index.rst

+3-2
@@ -3,6 +3,7 @@ Neuroevolution
 ==============
 
 .. toctree::
-   :maxdepth: 1
+   :maxdepth: 2
 
-   torchvision
+   reinforcement_learning/index
+   supervised_learning/index

docs/source/api/problems/rl/brax.rst renamed to docs/source/api/problems/neuroevolution/reinforcement_learning/brax.rst

+1-1
@@ -2,5 +2,5 @@
 Brax-based Problem
 ==================
 
-.. autoclass:: evox.problems.neuroevolution.Brax
+.. autoclass:: evox.problems.neuroevolution.reinforcement_learning.Brax
    :members:
@@ -0,0 +1,6 @@
+========
+Env Pool
+========
+
+.. autoclass:: evox.problems.neuroevolution.reinforcement_learning.EnvPool
+   :members:
@@ -0,0 +1,6 @@
+===
+Gym
+===
+
+.. autoclass:: evox.problems.neuroevolution.reinforcement_learning.Gym
+   :members:
@@ -0,0 +1,10 @@
+======================
+Reinforcement Learning
+======================
+
+.. toctree::
+   :maxdepth: 1
+
+   brax
+   gym
+   env_pool
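For readers updating imports after this reorganisation, a minimal sketch of the new module path is shown below. The class names and the package path are taken verbatim from the autoclass directives above; the old location in the comment comes from the renamed brax.rst page.

# New import location after this commit (mirrors the autoclass paths above):
from evox.problems.neuroevolution.reinforcement_learning import Brax, EnvPool, Gym

# Old location of the Brax problem before the move, kept only for comparison:
# from evox.problems.neuroevolution import Brax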
@@ -0,0 +1,8 @@
+===================
+Supervised Learning
+===================
+
+.. toctree::
+   :maxdepth: 1
+
+   torchvision
@@ -0,0 +1,6 @@
+===================
+Torchvision Dataset
+===================
+
+.. autoclass:: evox.problems.neuroevolution.supervised_learning.TorchvisionDataset
+   :members:

docs/source/api/problems/neuroevolution/torchvision.rst

-6
This file was deleted.

docs/source/api/problems/numerical/dtlz.rst

-5
@@ -23,8 +23,3 @@ DTLZ Test Suit
 .. autoclass:: evox.problems.numerical.DTLZ7
    :members:
 
-.. autoclass:: evox.problems.numerical.DTLZ8
-   :members:
-
-.. autoclass:: evox.problems.numerical.DTLZ9
-   :members:

docs/source/api/problems/rl/gym.rst

-6
This file was deleted.

docs/source/api/problems/rl/index.rst

-9
This file was deleted.

src/evox/algorithms/mo/__init__.py

+2
@@ -14,3 +14,5 @@
 from .sra import SRA
 from .tdea import TDEA
 from .bce_ibea import BCEIBEA
+from .lmocso import LMOCSO
+
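With the export above (and the evox.algorithms.LMOCSO path used by the new documentation page), the algorithm can be imported as sketched below. The constructor call is left as a commented, hypothetical placeholder because the class signature is not part of this diff.

from evox.algorithms import LMOCSO  # export added by this commit

# Hypothetical instantiation; lb, ub, n_objs and pop_size are placeholder
# names, not confirmed by this diff:
# algorithm = LMOCSO(lb=lb, ub=ub, n_objs=3, pop_size=100)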

src/evox/algorithms/mo/hype.py

+19-30
@@ -17,47 +17,36 @@
 from evox.operators import selection, mutation, crossover, non_dominated_sort
 
 
-@partial(jax.jit, static_argnums=[0, 1])
-def calculate_alpha(N, k):
-    alpha = jnp.zeros(N)
-
-    for i in range(1, k + 1):
-        num = jnp.prod((k - jnp.arange(1, i)) / (N - jnp.arange(1, i)))
-        alpha = alpha.at[i - 1].set(num / i)
-    return alpha
-
-
 @partial(jax.jit, static_argnums=[2, 3])
 def cal_hv(points, ref, k, n_sample, key):
     n, m = jnp.shape(points)
-    alpha = calculate_alpha(n, k)
+
+    # hit in alpha relevant partition
+    alpha = jnp.cumprod(
+        jnp.r_[1, (k - jnp.arange(1, n)) / (n - jnp.arange(1, n))]
+    ) / jnp.arange(1, n + 1)
 
     f_min = jnp.min(points, axis=0)
 
-    s = jax.random.uniform(key, shape=(n_sample, m), minval=f_min, maxval=ref)
+    samples = jax.random.uniform(key, shape=(n_sample, m), minval=f_min, maxval=ref)
 
-    pds = jnp.zeros((n, n_sample), dtype=bool)
+    # update hypervolume estimates
     ds = jnp.zeros((n_sample,))
-
-    def body_fun1(i, vals):
-        pds, ds = vals
-        x = jnp.sum((jnp.tile(points[i, :], (n_sample, 1)) - s) <= 0, axis=1) == m
-        pds = pds.at[i].set(jnp.where(x, True, pds[i]))
-        ds = jnp.where(x, ds + 1, ds)
-        return pds, ds
-
-    pds, ds = jax.lax.fori_loop(0, n, body_fun1, (pds, ds))
-    ds = ds - 1
-
-    f = jnp.zeros((n,))
-
-    def body_fun2(pd):
-        temp = jnp.where(pd, ds, -1).astype(int)
+    pds = jax.vmap(
+        lambda x: jnp.sum((jnp.tile(x, (n_sample, 1)) - samples) <= 0, axis=1) == m,
+        in_axes=0,
+        out_axes=0,
+    )(points)
+    ds = jnp.sum(jnp.where(pds, ds + 1, ds), axis=0)
+    ds = jnp.where(ds == 0, ds, ds - 1)
+
+    def cal_f(val):
+        temp = jnp.where(val, ds, -1).astype(int)
         value = jnp.where(temp != -1, alpha[temp], 0)
         value = jnp.sum(value)
         return value
-
-    f = jax.vmap(body_fun2)(pds)
+
+    f = jax.vmap(cal_f, in_axes=0, out_axes=0)(pds)
     f = f * jnp.prod(ref - f_min) / n_sample
 
     return f
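The rewritten hunk folds the deleted calculate_alpha helper into a single cumulative product inside cal_hv: alpha_i = (1/i) * prod_{j=1}^{i-1} (k - j) / (n - j). Below is a small self-contained comparison; calculate_alpha_loop re-implements the deleted helper and calculate_alpha_cumprod repeats the new expression, so the two can be checked against each other (the function names here are illustrative, not part of the codebase).

import jax.numpy as jnp


def calculate_alpha_loop(n, k):
    # Plain re-implementation of the deleted calculate_alpha helper.
    alpha = jnp.zeros(n)
    for i in range(1, k + 1):
        num = jnp.prod((k - jnp.arange(1, i)) / (n - jnp.arange(1, i)))
        alpha = alpha.at[i - 1].set(num / i)
    return alpha


def calculate_alpha_cumprod(n, k):
    # Closed form now used inside cal_hv: alpha_i = (1/i) * prod_{j=1}^{i-1} (k - j) / (n - j).
    return jnp.cumprod(
        jnp.r_[1, (k - jnp.arange(1, n)) / (n - jnp.arange(1, n))]
    ) / jnp.arange(1, n + 1)


n, k = 8, 3
print(calculate_alpha_loop(n, k))
print(calculate_alpha_cumprod(n, k))
# Both print the same vector: the first k entries follow the formula above and
# the rest are zero, because the cumulative product picks up the factor
# (k - k) / (n - k) = 0 at index k.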

src/evox/algorithms/mo/ibea.py

+26-19
@@ -37,6 +37,9 @@ class IBEA(Algorithm):
     """IBEA algorithm
 
     link: https://link.springer.com/chapter/10.1007/978-3-540-30217-9_84
+
+    Args:
+        kappa: fitness scaling factor. Default: 0.05
     """
 
     def __init__(
@@ -111,27 +114,31 @@ def _tell_normal(self, state, fitness):
         merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0)
         merged_obj = jnp.concatenate([state.fitness, fitness], axis=0)
 
-        n = jnp.shape(merged_pop)[0]
         merged_fitness, I, C = cal_fitness(merged_obj, self.kappa)
 
-        next_ind = jnp.arange(n)
-        vals = (next_ind, merged_fitness)
-
-        def body_fun(i, vals):
-            next_ind, merged_fitness = vals
-            x = jnp.argmin(merged_fitness)
-            merged_fitness += jnp.exp(-I[x, :] / C[x] / self.kappa)
-            merged_fitness = merged_fitness.at[x].set(jnp.max(merged_fitness))
-            next_ind = next_ind.at[x].set(-1)
-            return (next_ind, merged_fitness)
-
-        next_ind, merged_fitness = jax.lax.fori_loop(0, self.pop_size, body_fun, vals)
-
-        ind = jnp.where(next_ind != -1, size=n, fill_value=-1)[0]
-        ind_n = ind[0 : self.pop_size]
-
-        survivor = merged_pop[ind_n]
-        survivor_fitness = merged_obj[ind_n]
+        # Different from the original paper, the selection here is directly through fitness.
+        next_ind = jnp.argsort(-merged_fitness)[0: self.pop_size]
+
+        # The following code is from the original paper's implementation
+        # and is kept for reference purposes but is not being used in this version.
+        # n = jnp.shape(merged_pop)[0]
+        # next_ind = jnp.arange(n)
+        # vals = (next_ind, merged_fitness)
+        # def body_fun(i, vals):
+        #     next_ind, merged_fitness = vals
+        #     x = jnp.argmin(merged_fitness)
+        #     merged_fitness += jnp.exp(-I[x, :] / C[x] / self.kappa)
+        #     merged_fitness = merged_fitness.at[x].set(jnp.max(merged_fitness))
+        #     next_ind = next_ind.at[x].set(-1)
+        #     return (next_ind, merged_fitness)
+        #
+        # next_ind, merged_fitness = jax.lax.fori_loop(0, self.pop_size, body_fun, vals)
+        #
+        # next_ind = jnp.where(next_ind != -1, size=n, fill_value=-1)[0]
+        # next_ind = next_ind[0: self.pop_size]
+
+        survivor = merged_pop[next_ind]
+        survivor_fitness = merged_obj[next_ind]
 
         state = state.update(population=survivor, fitness=survivor_fitness)
 
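The environmental selection in _tell_normal is reduced to a single argsort over the indicator fitness. The standalone sketch below shows only that selection step with made-up toy values; merged_fitness, merged_pop and pop_size stand in for the variables of the same names in the method.

import jax.numpy as jnp

# Toy stand-ins for the variables in _tell_normal; the numbers are made up.
merged_fitness = jnp.array([0.9, -0.2, 0.5, 1.3, 0.1, -0.7])
merged_pop = jnp.arange(12).reshape(6, 2)  # 6 individuals, 2 decision variables
pop_size = 3

# New selection: keep the pop_size individuals with the largest indicator
# fitness in one shot, instead of iteratively removing the worst individual
# and re-weighting the rest as in the commented-out reference loop above.
next_ind = jnp.argsort(-merged_fitness)[:pop_size]
survivor = merged_pop[next_ind]
print(next_ind)   # [3 0 2]
print(survivor)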
