diff --git a/src/evox/algorithms/__init__.py b/src/evox/algorithms/__init__.py index 7063be560..e135707a1 100644 --- a/src/evox/algorithms/__init__.py +++ b/src/evox/algorithms/__init__.py @@ -1,10 +1,23 @@ __all__ = [ # DE Variants "DE", + "SHADE", + "CoDE", + "SaDE", "ODE", "JaDE", # ES Variants "OpenES", + "XNES", + "SeparableNES", + "DES", + "SNES", + "ARS", + "ASEBO", + "PersistentES", + "NoiseReuseES", + "GuidedES", + "ESMC", "CMAES", # PSO Variants "CLPSO", @@ -21,7 +34,7 @@ ] -from .de_variants import DE, ODE, JaDE -from .es_variants import CMAES, OpenES +from .de_variants import DE, ODE, SHADE, CoDE, JaDE, SaDE +from .es_variants import ARS, ASEBO, CMAES, DES, ESMC, SNES, XNES, GuidedES, NoiseReuseES, OpenES, PersistentES, SeparableNES from .mo import MOEAD, NSGA2, RVEA from .pso_variants import CLPSO, CSO, DMSPSOEL, FSPSO, PSO, SLPSOGS, SLPSOUS diff --git a/src/evox/algorithms/de_variants/__init__.py b/src/evox/algorithms/de_variants/__init__.py index 65b541ebb..1b275c133 100644 --- a/src/evox/algorithms/de_variants/__init__.py +++ b/src/evox/algorithms/de_variants/__init__.py @@ -1,6 +1,9 @@ -__all__ = ["DE", "ODE", "JaDE"] +__all__ = ["DE", "CoDE", "JaDE", "ODE", "SaDE", "SHADE"] +from .code import CoDE from .de import DE from .jade import JaDE from .ode import ODE +from .sade import SaDE +from .shade import SHADE diff --git a/src/evox/algorithms/de_variants/code.py b/src/evox/algorithms/de_variants/code.py new file mode 100644 index 000000000..2c8f111e7 --- /dev/null +++ b/src/evox/algorithms/de_variants/code.py @@ -0,0 +1,151 @@ +import torch + +from ...core import Algorithm, Mutable, Parameter, jit_class +from ...operators.crossover import ( + DE_arithmetic_recombination, + DE_binary_crossover, + DE_differential_sum, + DE_exponential_crossover, +) +from ...operators.selection import select_rand_pbest +from ...utils import clamp + +""" +Strategy codes(4 bits): [base_vec_prim, base_vec_sec, diff_num, cross_strategy] +base_vec : 0="rand", 1="best", 2="pbest", 3="current" +cross_strategy: 0=bin , 1=exp , 2=arith +""" + +rand_1_bin = [0, 0, 1, 0] +rand_2_bin = [0, 0, 2, 0] +current2rand_1 = [0, 0, 1, 2] # current2rand_1 <==> rand_1_arith +rand2best_2_bin = [0, 1, 2, 0] +current2pbest_1_bin = [3, 2, 1, 0] + + +@jit_class +class CoDE(Algorithm): + """The implementation of CoDE algorithm. + + Reference: + Wang Y, Cai Z, Zhang Q. Differential evolution with composite trial vector generation strategies and control parameters[J]. IEEE transactions on evolutionary computation, 2011, 15(1): 55-66. + """ + + def __init__( + self, + pop_size: int, + lb: torch.Tensor, + ub: torch.Tensor, + diff_padding_num: int = 5, + param_pool: torch.Tensor = torch.tensor([[1, 0.1], [1, 0.9], [0.8, 0.2]]), + replace: bool = False, + device: torch.device | None = None, + ): + """ + Initialize the CoDE algorithm with the given parameters. + + :param pop_size: The size of the population. + :param lb: The lower bounds of the search space. Must be a 1D tensor. + :param ub: The upper bounds of the search space. Must be a 1D tensor. + :param diff_padding_num: The number of differential padding vectors to use. Defaults to 5. + :param param_pool: A tensor of control parameter pairs for the algorithm. Defaults to a predefined tensor. + :param replace: A boolean indicating whether to replace individuals in the population. Defaults to False. + :param device: The device to use for tensor computations. Defaults to None. 
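+
+        Example (illustrative values; in practice the algorithm is driven by a
+        workflow that supplies the ``evaluate`` callback):
+
+        >>> import torch
+        >>> from evox.algorithms import CoDE
+        >>> algo = CoDE(pop_size=50, lb=-10 * torch.ones(10), ub=10 * torch.ones(10))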
+ """ + super().__init__() + device = torch.get_default_device() if device is None else device + dim = lb.shape[0] + # parameters + self.param_pool = Parameter(param_pool, device=device) + # set value + lb = lb[None, :].to(device=device) + ub = ub[None, :].to(device=device) + self.lb = lb + self.ub = ub + self.dim = dim + self.replace = replace + self.pop_size = pop_size + self.diff_padding_num = diff_padding_num + self.strategies = torch.tensor([rand_1_bin, rand_2_bin, current2rand_1], device=device) + # setup + self.best_index = Mutable(torch.tensor(0, device=device)) + self.population = Mutable(torch.randn(pop_size, dim, device=device) * (ub - lb) + lb) + self.fitness = Mutable(torch.full((self.pop_size,), fill_value=torch.inf, device=device)) + + def step(self): + """Perform one iteration of the CoDE algorithm. + + This step is composed of the following steps: + 1. Generate trial vectors using the differential sum. + 2. Apply crossover to generate a new vector. + 3. Apply mutation to generate a new vector. + 4. Update the population and fitness values. + """ + device = self.population.device + indices = torch.arange(self.pop_size, device=device) + + param_ids = torch.randint(0, 3, (3, self.pop_size), device=device) + + base_vec_prim_type = self.strategies[:, 0] + base_vec_sec_type = self.strategies[:, 1] + num_diff_vectors = self.strategies[:, 2] + cross_strategy = self.strategies[:, 3] + + params = self.param_pool[param_ids] + differential_weight = params[:, :, 0] + cross_probability = params[:, :, 1] + + trial_vectors = torch.zeros((3, self.pop_size, self.dim), device=device) + + for i in range(3): + difference_sum, rand_vec_idx = DE_differential_sum( + self.diff_padding_num, + num_diff_vectors[i], + indices, + self.population, + #self.replace + ) + + rand_vec = self.population[rand_vec_idx] + best_vec = torch.tile(self.population[self.best_index].unsqueeze(0), (self.pop_size, 1)) + pbest_vec = select_rand_pbest(0.05, self.population, self.fitness) + current_vec = self.population[indices] + + vec_merge = torch.stack((rand_vec, best_vec, pbest_vec, current_vec)) + base_vec_prim = vec_merge[base_vec_prim_type[i]] + base_vec_sec = vec_merge[base_vec_sec_type[i]] + + base_vec = base_vec_prim + differential_weight[i].unsqueeze(1) * (base_vec_sec - base_vec_prim) + mutation_vec = base_vec + difference_sum * differential_weight[i].unsqueeze(1) + + trial_vec = torch.zeros(self.pop_size, self.dim, device=device) + trial_vec = torch.where( + cross_strategy[i] == 0, DE_binary_crossover(mutation_vec, current_vec, cross_probability[i]), trial_vec + ) + trial_vec = torch.where( + cross_strategy[i] == 1, DE_exponential_crossover(mutation_vec, current_vec, cross_probability[i]), trial_vec + ) + trial_vec = torch.where( + cross_strategy[i] == 2, DE_arithmetic_recombination(mutation_vec, current_vec, cross_probability[i]), trial_vec + ) + trial_vectors = torch.where( + (torch.arange(3, device=device) == i).unsqueeze(1).unsqueeze(2), trial_vec.unsqueeze(0), trial_vectors + ) + + trial_vectors = clamp(trial_vectors.reshape(3 * self.pop_size, self.dim), self.lb, self.ub) + trial_fitness = self.evaluate(trial_vectors) + + indices = torch.arange(3 * self.pop_size, device=device).reshape(3, self.pop_size) + trans_fit = trial_fitness[indices] + + min_indices = torch.argmin(trans_fit, dim=0) + min_indices_global = indices[min_indices, torch.arange(self.pop_size, device=device)] + + trial_fitness_select = trial_fitness[min_indices_global] + trial_vectors_select = trial_vectors[min_indices_global] + + compare = 
trial_fitness_select <= self.fitness + + self.population = torch.where(compare[:, None], trial_vectors_select, self.population) + self.fitness = torch.where(compare, trial_fitness_select, self.fitness) + self.best_index = torch.argmin(self.fitness) diff --git a/src/evox/algorithms/de_variants/jade.py b/src/evox/algorithms/de_variants/jade.py index 7cc5fd1be..a3d97bedf 100644 --- a/src/evox/algorithms/de_variants/jade.py +++ b/src/evox/algorithms/de_variants/jade.py @@ -114,9 +114,9 @@ def step(self): [self.population[random_choices[i]] - self.population[random_choices[i + 1]] for i in range(1, num_vec - 1, 2)] ).sum(dim=0) - pbest_vects = self._select_rand_pbest_vects(p=0.05) + pbest_vectors = self._select_rand_pbest_vectors(p=0.05) base_vectors_prim = self.population - base_vectors_sec = pbest_vects + base_vectors_sec = pbest_vectors F_vec_2D = F_vec[:, None] base_vectors = base_vectors_prim + F_vec_2D * (base_vectors_sec - base_vectors_prim) @@ -162,7 +162,7 @@ def step(self): self.F_u = torch.where(count_mask, updated_F_u, self.F_u) self.CR_u = torch.where(count_mask, updated_CR_u, self.CR_u) - def _select_rand_pbest_vects(self, p: float) -> torch.Tensor: + def _select_rand_pbest_vectors(self, p: float) -> torch.Tensor: """ Select p-best vectors from the population for mutation. @@ -181,6 +181,6 @@ def _select_rand_pbest_vects(self, p: float) -> torch.Tensor: pbest_indices = pbest_indices_pool[random_indices] # Retrieve p-best vectors using the sampled indices - pbest_vects = self.population[pbest_indices] + pbest_vectors = self.population[pbest_indices] - return pbest_vects + return pbest_vectors diff --git a/src/evox/algorithms/de_variants/sade.py b/src/evox/algorithms/de_variants/sade.py new file mode 100644 index 000000000..da84476a9 --- /dev/null +++ b/src/evox/algorithms/de_variants/sade.py @@ -0,0 +1,216 @@ +import torch + +from ...core import Algorithm, Mutable, jit_class, vmap_impl +from ...operators.crossover import ( + DE_arithmetic_recombination, + DE_binary_crossover, + DE_differential_sum, + DE_exponential_crossover, +) +from ...operators.selection import select_rand_pbest +from ...utils import clamp + +# Strategy codes(4 bits): [base_vec_prim, base_vec_sec, diff_num, cross_strategy] +# base_vec: 0="rand", 1="best", 2="pbest", 3="current", cross_strategy: 0=bin, 1=exp, 2=arith +rand_1_bin = [0, 0, 1, 0] +rand_2_bin = [0, 0, 2, 0] +rand2best_2_bin = [0, 1, 2, 0] +current2rand_1 = [0, 0, 1, 2] # current2rand_1 <==> rand_1_arith + + +@jit_class +class SaDE(Algorithm): + """The implementation of SaDE algorithm. + + Reference: + Qin A K, Huang V L, Suganthan P N. + Differential evolution algorithm with strategy adaptation for global numerical optimization[J]. + IEEE transactions on Evolutionary Computation, 2008, 13(2): 398-417. + """ + + def __init__( + self, + pop_size: int, + lb: torch.Tensor, + ub: torch.Tensor, + diff_padding_num: int = 9, + LP: int = 50, + device: torch.device | None = None, + ): + """ + Initialize the SaDE algorithm with the given parameters. + + :param pop_size: The size of the population. + :param lb: The lower bounds of the search space. Must be a 1D tensor. + :param ub: The upper bounds of the search space. Must be a 1D tensor. + :param diff_padding_num: The number of differential padding vectors to use. Defaults to 9. + :param LP: The size of memory. Defaults to 50. + :param device: The device to use for tensor computations (e.g., "cpu" or "cuda"). Defaults to None. 
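+
+        Example (illustrative values; note that ``pop_size`` must be at least 9, and
+        ``evaluate`` is supplied by the surrounding workflow):
+
+        >>> import torch
+        >>> from evox.algorithms import SaDE
+        >>> algo = SaDE(pop_size=50, lb=-10 * torch.ones(10), ub=10 * torch.ones(10))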
+        """
+        super().__init__()
+        device = torch.get_default_device() if device is None else device
+        assert pop_size >= 9
+        assert lb.shape == ub.shape and lb.ndim == 1 and ub.ndim == 1 and lb.dtype == ub.dtype
+        dim = lb.shape[0]
+        # parameters
+        # set value
+        lb = lb[None, :].to(device=device)
+        ub = ub[None, :].to(device=device)
+        self.lb = lb
+        self.ub = ub
+        self.LP = LP
+        self.dim = dim
+        self.pop_size = pop_size
+        self.diff_padding_num = diff_padding_num
+        self.strategy_pool = torch.tensor([rand_1_bin, rand2best_2_bin, rand_2_bin, current2rand_1], device=device)
+        # setup
+        self.gen_iter = Mutable(torch.tensor(0, device=device))
+        self.best_index = Mutable(torch.tensor(0, device=device))
+        self.Memory_FCR = Mutable(torch.full((2, 100), fill_value=0.5, device=device))
+        self.population = Mutable(torch.randn(pop_size, dim, device=device) * (ub - lb) + lb)
+        self.fitness = Mutable(torch.full((self.pop_size,), fill_value=torch.inf, device=device))
+        self.success_memory = Mutable(torch.full((LP, 4), fill_value=0, device=device))
+        self.failure_memory = Mutable(torch.full((LP, 4), fill_value=0, device=device))
+        self.CR_memory = Mutable(torch.full((LP, 4), fill_value=torch.nan, device=device))
+        # Others
+        self.g_cpu = torch.Generator()
+        # only create a CUDA generator when CUDA is present, so CPU-only builds can construct the algorithm
+        self.g_cuda = torch.Generator(device="cuda") if torch.cuda.is_available() else None

+    def _get_strategy_ids(self, strategy_p: torch.Tensor, device: torch.device):
+        if device.type == "cuda":
+            generator = self.g_cuda
+        else:
+            generator = self.g_cpu
+        strategy_ids = torch.multinomial(strategy_p, self.pop_size, replacement=True, generator=generator)
+        return strategy_ids
+
+    @vmap_impl(_get_strategy_ids)
+    def _vmap_get_strategy_ids(self, strategy_p: torch.Tensor, device: torch.device):
+        # TODO: since torch.multinomial is not supported in vmap, we have to use torch.randint
+        strategy_ids = torch.randint(0, 4, (self.pop_size,), device=device)
+        return strategy_ids
+
+    def step(self):
+        """
+        Execute a single optimization step of the SaDE algorithm.
+
+        This involves the following sub-steps:
+        1. Generate a new population using differential evolution.
+        2. Evaluate the fitness of the new population.
+        3. Update the best individual and best fitness.
+        4. Update the success and failure memory.
+        5. Update the CR memory.
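+
+        After the first $LP$ generations, the selection probability of strategy $k$ is
+        derived from the two memories as $p_k = S_k / \\sum_j S_j$ with
+        $S_k = ns_k / (ns_k + nf_k) + 0.01$, where $ns_k$ and $nf_k$ are the success and
+        failure counts of strategy $k$ over the last $LP$ generations and the constant
+        0.01 keeps rarely-successful strategies selectable; before that, the four
+        strategies are drawn uniformly with $p_k = 0.25$.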
+ """ + device = self.population.device + indices = torch.arange(self.pop_size, device=device) + + CRM_init = torch.tensor([0.5, 0.5, 0.5, 0.5], device=device) + strategy_p_init = torch.tensor([0.25, 0.25, 0.25, 0.25], device=device) + + success_sum = torch.sum(self.success_memory, dim=0) + failure_sum = torch.sum(self.failure_memory, dim=0) + S_mat = (success_sum / (success_sum + failure_sum)) + 0.01 + + strategy_p_update = S_mat / torch.sum(S_mat) + strategy_p = torch.where(self.gen_iter >= self.LP, strategy_p_update, strategy_p_init) + + CRM_update = torch.median(self.CR_memory, dim=0)[0] + CRM = torch.where(self.gen_iter > self.LP, CRM_update, CRM_init) + + strategy_ids = self._get_strategy_ids(strategy_p, device) + + CRs_vec = torch.randn((self.pop_size, 4), device=device) * 0.1 + CRM + CRs_vec_repair = torch.randn((self.pop_size, 4), device=device) * 0.1 + CRM + + mask = (CRs_vec < 0) | (CRs_vec > 1) + CRs_vec = torch.where(mask, CRs_vec_repair, CRs_vec) + + differential_weight = torch.randn(self.pop_size, device=device) * 0.3 + 0.5 + cross_probability = torch.gather(CRs_vec, 1, strategy_ids[:, None])[:, 0] + + strategy_code = self.strategy_pool[strategy_ids] + base_vec_prim_type = strategy_code[:, 0] + base_vec_sec_type = strategy_code[:, 1] + num_diff_vectors = strategy_code[:, 2] + cross_strategy = strategy_code[:, 3] + + difference_sum, rand_vec_idx = DE_differential_sum(self.diff_padding_num, num_diff_vectors, indices, self.population) + + rand_vec = self.population[rand_vec_idx] + best_vec = torch.tile(self.population[self.best_index].unsqueeze(0), (self.pop_size, 1)) + pbest_vec = select_rand_pbest(0.05, self.population, self.fitness) + current_vec = self.population[indices] + vector_merge = torch.stack([rand_vec, best_vec, pbest_vec, current_vec]) + + base_vector_prim = torch.zeros(self.pop_size, 4, device=device) + base_vector_sec = torch.zeros(self.pop_size, 4, device=device) + + for i in range(4): + base_vector_prim = torch.where(base_vec_prim_type.unsqueeze(1) == i, vector_merge[i], base_vector_prim) + base_vector_sec = torch.where(base_vec_sec_type.unsqueeze(1) == i, vector_merge[i], base_vector_sec) + + base_vector = base_vector_prim + differential_weight.unsqueeze(1) * (base_vector_sec - base_vector_prim) + mutation_vector = base_vector + difference_sum * differential_weight.unsqueeze(1) + + trial_vector = torch.zeros(self.pop_size, self.dim, device=device) + trial_vector = torch.where( + cross_strategy.unsqueeze(1) == 0, + DE_binary_crossover(mutation_vector, current_vec, cross_probability), + trial_vector, + ) + trial_vector = torch.where( + cross_strategy.unsqueeze(1) == 1, + DE_exponential_crossover(mutation_vector, current_vec, cross_probability), + trial_vector, + ) + trial_vector = torch.where( + cross_strategy.unsqueeze(1) == 2, + DE_arithmetic_recombination(mutation_vector, current_vec, cross_probability), + trial_vector, + ) + trial_vector = clamp(trial_vector, self.lb, self.ub) + + CRs_vec = torch.gather(CRs_vec, 1, strategy_ids.unsqueeze(1)).squeeze(1) + + trial_fitness = self.evaluate(trial_vector) + + self.gen_iter = self.gen_iter + 1 + + compare = trial_fitness <= self.fitness + + self.population = torch.where(compare[:, None], trial_vector, self.population) + self.fitness = torch.where(compare, trial_fitness, self.fitness) + + self.best_index = torch.argmin(self.fitness) + + """Update memories""" + success_memory = torch.roll(self.success_memory, 1, 0) + success_memory[0, :] = 0 + failure_memory = torch.roll(self.failure_memory, 1, 0) + 
failure_memory[0, :] = 0 + + for i in range(self.pop_size): + success_memory_up = success_memory.clone() + success_memory_up[0][strategy_ids[i]] += 1 + success_memory = torch.where(compare[i], success_memory_up, success_memory) + + failure_memory_up = failure_memory.clone() + failure_memory_up[0][strategy_ids[i]] += 1 + failure_memory = torch.where(compare[i], failure_memory, failure_memory_up) + + CR_memory = self.CR_memory + + for i in range(self.pop_size): + str_idx = strategy_ids[i] + + CR_mk_up = torch.roll(CR_memory.t()[str_idx], 1) + CR_mk_up[0] = CRs_vec[i] + + CR_memory_up = CR_memory.clone().t() + CR_memory_up[str_idx][:] = CR_mk_up + CR_memory_up = CR_memory_up.t() + CR_memory = torch.where(compare[i], CR_memory_up, CR_memory) + + self.success_memory = success_memory + self.failure_memory = failure_memory + self.CR_memory = CR_memory diff --git a/src/evox/algorithms/de_variants/shade.py b/src/evox/algorithms/de_variants/shade.py new file mode 100644 index 000000000..690c70c02 --- /dev/null +++ b/src/evox/algorithms/de_variants/shade.py @@ -0,0 +1,149 @@ +import torch + +from ...core import Algorithm, Mutable, Parameter, jit_class +from ...operators.crossover import ( + DE_binary_crossover, + DE_differential_sum, +) +from ...operators.selection import select_rand_pbest +from ...utils import clamp + + +@jit_class +class SHADE(Algorithm): + """The implementation of SHADE algorithm. + + Reference: + Tanabe R, Fukunaga A. + Success-history based parameter adaptation for differential evolution[C]//2013 + IEEE congress on evolutionary computation. IEEE, 2013: 71-78. + """ + + def __init__( + self, + pop_size: int, + lb: torch.Tensor, + ub: torch.Tensor, + diff_padding_num: int = 9, + device: torch.device | None = None, + ): + """ + Initialize the SHADE algorithm with the given parameters. + + :param pop_size: The size of the population. + :param lb: The lower bounds of the search space. Must be a 1D tensor. + :param ub: The upper bounds of the search space. Must be a 1D tensor. + :param diff_padding_num: The number of differential padding vectors to use. Defaults to 9. + :param device: The device to use for tensor computations (e.g., "cpu" or "cuda"). Defaults to None. + """ + super().__init__() + device = torch.get_default_device() if device is None else device + assert pop_size >= 9 + assert lb.shape == ub.shape and lb.ndim == 1 and ub.ndim == 1 and lb.dtype == ub.dtype + dim = lb.shape[0] + # parameters + self.num_diff_vectors = Parameter(torch.tensor(1), device=device) + # set value + lb = lb[None, :].to(device=device) + ub = ub[None, :].to(device=device) + self.lb = lb + self.ub = ub + self.dim = dim + self.pop_size = pop_size + self.diff_padding_num = diff_padding_num + # setup + self.best_index = Mutable(torch.tensor(0, device=device)) + self.Memory_FCR = Mutable(torch.full((2, 100), fill_value=0.5, device=device)) + self.population = Mutable(torch.randn(pop_size, dim, device=device) * (ub - lb) + lb) + self.fitness = Mutable(torch.full((self.pop_size,), fill_value=torch.inf, device=device)) + + def step(self): + """ + Perform a single step of the SHADE algorithm. + + This involves the following sub-steps: + 1. Generate trial vectors using the SHADE algorithm. + 2. Evaluate the fitness of the trial vectors. + 3. Update the population. + 4. Update the memory. 
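+
+        The memory update weighs each successful trial by its fitness improvement
+        $\\Delta_i$, i.e. $w_i = \\Delta_i / \\sum_j \\Delta_j$, and stores the weighted mean
+        $M_{CR} = \\sum_i w_i \\cdot CR_i$ together with the weighted Lehmer mean
+        $M_F = \\sum_i w_i F_i^2 / \\sum_i w_i F_i$ in the history; new $F$ and $CR$ values
+        are then sampled around entries of this history.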
+ """ + device = self.population.device + indices = torch.arange(self.pop_size, device=device) + + FCR_ids = torch.randperm(self.pop_size) + M_F_vect = self.Memory_FCR[0, FCR_ids] + M_CR_vect = self.Memory_FCR[1, FCR_ids] + + F_vect = torch.randn(self.pop_size, device=device) * 0.1 + M_F_vect + F_vect = clamp(F_vect, torch.zeros(self.pop_size, device=device), torch.ones(self.pop_size, device=device)) + + CR_vect = torch.randn(self.pop_size, device=device) * 0.1 + M_CR_vect + CR_vect = clamp(CR_vect, torch.zeros(self.pop_size, device=device), torch.ones(self.pop_size, device=device)) + + difference_sum, rand_vect_idx = DE_differential_sum( + self.diff_padding_num, torch.tile(self.num_diff_vectors, (self.pop_size,)), indices, self.population + ) + pbest_vect = select_rand_pbest(0.05, self.population, self.fitness) + current_vect = self.population[indices] + + base_vector_prim = current_vect + base_vector_sec = pbest_vect + + base_vector = base_vector_prim + F_vect.unsqueeze(1) * (base_vector_sec - base_vector_prim) + mutation_vector = base_vector + difference_sum * F_vect.unsqueeze(1) + + trial_vector = DE_binary_crossover(mutation_vector, current_vect, CR_vect) + trial_vector = clamp(trial_vector, self.lb, self.ub) + + trial_fitness = self.evaluate(trial_vector) + + compare = trial_fitness < self.fitness + + population_update = torch.where(compare[:, None], trial_vector, self.population) + fitness_update = torch.where(compare, trial_fitness, self.fitness) + + self.population = population_update + self.fitness = fitness_update + + self.best_index = torch.argmin(self.fitness) + + S_F = torch.full((self.pop_size,), fill_value=torch.nan, device=device) + S_CR = torch.full((self.pop_size,), fill_value=torch.nan, device=device) + S_delta = torch.full((self.pop_size,), fill_value=torch.nan, device=device) + + deltas = self.fitness - trial_fitness + + for i in range(self.pop_size): # get_success_delta + is_success = compare[i].float() + F = F_vect[i] + CR = CR_vect[i] + delta = deltas[i] + + S_F_update_temp = torch.roll(S_F, shifts=1) + S_F_update = torch.cat([F.unsqueeze(0), S_F_update_temp[1:]], dim=0) + + S_CR_update_temp = torch.roll(S_CR, shifts=1) + S_CR_update = torch.cat([CR.unsqueeze(0), S_CR_update_temp[1:]], dim=0) + + S_delta_update_temp = torch.roll(S_delta, shifts=1) + S_delta_update = torch.cat([delta.unsqueeze(0), S_delta_update_temp[1:]], dim=0) + + S_F = is_success * S_F_update + (1.0 - is_success) * S_F_update_temp + S_CR = is_success * S_CR_update + (1.0 - is_success) * S_CR_update_temp + S_delta = is_success * S_delta_update + (1.0 - is_success) * S_delta_update_temp + + norm_delta = S_delta / torch.nansum(S_delta) + M_CR = torch.nansum(norm_delta * S_CR) + M_F = torch.nansum(norm_delta * (S_F**2)) / torch.nansum(norm_delta * S_F) + + Memory_FCR_update = torch.roll(self.Memory_FCR, shifts=1, dims=1) + Memory_FCR_update[0, 0] = M_F + Memory_FCR_update[1, 0] = M_CR + + is_F_nan = torch.isnan(M_F) + Memory_FCR_update = torch.where(is_F_nan, self.Memory_FCR, Memory_FCR_update) + + is_S_nan = torch.all(torch.isnan(compare)) + Memory_FCR = torch.where(is_S_nan, self.Memory_FCR, Memory_FCR_update) + + self.Memory_FCR = Memory_FCR diff --git a/src/evox/algorithms/es_variants/__init__.py b/src/evox/algorithms/es_variants/__init__.py index d30049542..2dc26beee 100644 --- a/src/evox/algorithms/es_variants/__init__.py +++ b/src/evox/algorithms/es_variants/__init__.py @@ -1,5 +1,27 @@ -__all__ = ["OpenES", "CMAES"] +__all__ = [ + "OpenES", + "XNES", + "SeparableNES", + "DES", + "SNES", + 
"ARS", + "ASEBO", + "PersistentES", + "NoiseReuseES", + "GuidedES", + "ESMC", + "CMAES", +] +from .ars import ARS +from .asebo import ASEBO from .cma_es import CMAES +from .des import DES +from .esmc import ESMC +from .guided_es import GuidedES +from .nes import XNES, SeparableNES +from .noise_reuse_es import NoiseReuseES from .open_es import OpenES +from .persistent_es import PersistentES +from .snes import SNES diff --git a/src/evox/algorithms/es_variants/ars.py b/src/evox/algorithms/es_variants/ars.py new file mode 100644 index 000000000..7168be93d --- /dev/null +++ b/src/evox/algorithms/es_variants/ars.py @@ -0,0 +1,97 @@ +from typing import Literal + +import torch + +from ...core import Algorithm, Mutable, Parameter, jit_class +from .adam_step import adam_single_tensor + + +@jit_class +class ARS(Algorithm): + """The implementation of the ARS algorithm. + + Reference: + Simple random search provides a competitive approach to reinforcement learning (https://arxiv.org/pdf/1803.07055.pdf) + + This code has been inspired by or utilizes the algorithmic implementation from evosax. + More information about evosax can be found at the following URL: + GitHub Link: https://github.com/RobertTLange/evosax + """ + def __init__( + self, + pop_size: int, + center_init: torch.Tensor, + elite_ratio: float = 0.1, + lr: float = 0.05, + sigma: float = 0.03, + optimizer: Literal["adam"] | None = None, + device: torch.device | None = None, + ): + """Initialize the ARS algorithm with the given parameters. + + :param pop_size: The size of the population. + :param center_init: The initial center of the population. Must be a 1D tensor. + :param elite_ratio: The ratio of elite population. Defaults to 0.1. + :param lr: The learning rate for the optimizer. Defaults to 0.05. + :param sigma: The standard deviation of the noise. Defaults to 0.03. + :param optimizer: The optimizer to use. Defaults to None. Currently, only "adam" or None is supported. + :param device: The device to use for the tensors. Defaults to None. 
+ """ + super().__init__() + assert pop_size > 1 + assert 0 <= elite_ratio <= 1 + dim = center_init.shape[0] + # set hyperparameters + self.lr = Parameter(lr, device=device) + self.sigma = Parameter(sigma, device=device) + # set value + self.dim = dim + self.pop_size = pop_size + self.optimizer = optimizer + self.elite_pop_size = max(1, int(pop_size / 2 * elite_ratio)) + # setup + center_init = center_init.to(device=device) + self.center = Mutable(center_init) + + if optimizer == "adam": + self.exp_avg = Mutable(torch.zeros_like(self.center)) + self.exp_avg_sq = Mutable(torch.zeros_like(self.center)) + self.beta1 = Parameter(0.9, device=device) + self.beta2 = Parameter(0.999, device=device) + + def step(self): + """Perform a single step of the ARS algorithm.""" + device = self.center.device + + z_plus = torch.randn(int(self.pop_size / 2), self.dim, device=device) + noise = torch.cat([z_plus, -1.0 * z_plus]) + population = self.center + self.sigma * noise + + fitness = self.evaluate(population) + + noise_1 = noise[: int(self.pop_size / 2)] + fit_1 = fitness[: int(self.pop_size / 2)] + fit_2 = fitness[int(self.pop_size / 2) :] + elite_idx = torch.minimum(fit_1, fit_2).argsort()[: self.elite_pop_size] + + fitness_elite = torch.cat([fit_1[elite_idx], fit_2[elite_idx]]) + sigma_fitness = torch.std(fitness_elite) + 1e-05 + + fit_diff = fit_1[elite_idx] - fit_2[elite_idx] + fit_diff_noise = noise_1[elite_idx].T @ fit_diff + + theta_grad = 1.0 / (self.elite_pop_size * sigma_fitness) * fit_diff_noise + + if self.optimizer is None: + center = self.center - self.lr * theta_grad + else: + center, self.exp_avg, self.exp_avg_sq = adam_single_tensor( + self.center, + theta_grad, + self.exp_avg, + self.exp_avg_sq, + self.beta1, + self.beta2, + self.lr, + ) + self.center = center diff --git a/src/evox/algorithms/es_variants/asebo.py b/src/evox/algorithms/es_variants/asebo.py new file mode 100644 index 000000000..52dd90c40 --- /dev/null +++ b/src/evox/algorithms/es_variants/asebo.py @@ -0,0 +1,156 @@ +from typing import Literal + +import torch + +from ...core import Algorithm, Mutable, Parameter, jit_class +from .adam_step import adam_single_tensor + + +@jit_class +class ASEBO(Algorithm): + """The implementation of the ASEBO algorithm. + + Reference: + From Complexity to Simplicity: Adaptive ES-Active Subspaces for Blackbox Optimization + (https://arxiv.org/abs/1903.04268) + + This code has been inspired by or utilizes the algorithmic implementation from evosax. + More information about evosax can be found at the following URL: + GitHub Link: https://github.com/RobertTLange/evosax + """ + def __init__( + self, + pop_size: int, + center_init: torch.Tensor, + optimizer: Literal["adam"] | None = None, + lr: float = 0.05, + lr_decay: float = 1.0, + lr_limit: float = 0.001, + sigma: float = 0.03, + sigma_decay: float = 1.0, + sigma_limit: float = 0.01, + subspace_dims: int | None = None, + device: torch.device | None = None, + ): + """Initialize the ARS algorithm with the given parameters. + + :param pop_size: The size of the population. + :param center_init: The initial center of the population. Must be a 1D tensor. + :param optimizer: The optimizer to use. Defaults to None. Currently, only "adam" or None is supported. + :param lr: The learning rate for the optimizer. Defaults to 0.05. + :param lr_decay: The decay factor for the learning rate. Defaults to 1.0. + :param lr_limit: The minimum value for the learning rate. Defaults to 0.001. + :param sigma: The standard deviation of the noise. Defaults to 0.03. 
+        :param sigma_decay: The decay factor for the standard deviation. Defaults to 1.0.
+        :param sigma_limit: The minimum value for the standard deviation. Defaults to 0.01.
+        :param subspace_dims: The dimension of the subspace. Defaults to None.
+        :param device: The device to use for the tensors. Defaults to None.
+        """
+        super().__init__()
+        assert pop_size > 1
+        dim = center_init.shape[0]
+        if subspace_dims is None:
+            subspace_dims = dim
+        # set hyperparameters
+        self.lr = Parameter(lr, device=device)
+        self.lr_decay = Parameter(lr_decay, device=device)
+        self.lr_limit = Parameter(lr_limit, device=device)
+        self.sigma_decay = Parameter(sigma_decay, device=device)
+        self.sigma_limit = Parameter(sigma_limit, device=device)
+        # set value
+        self.dim = dim
+        self.pop_size = pop_size
+        self.optimizer = optimizer
+        self.subspace_dims = subspace_dims
+        # setup
+        center_init = center_init.to(device=device)
+        self.center = Mutable(center_init)
+        self.grad_subspace = Mutable(torch.zeros(self.subspace_dims, self.dim, device=device))
+        self.UUT = Mutable(torch.zeros(self.dim, self.dim, device=device))
+        self.UUT_ort = Mutable(torch.zeros(self.dim, self.dim, device=device))
+        self.sigma = Mutable(torch.tensor(sigma, device=device))
+        self.alpha = Mutable(torch.tensor(0.1, device=device))
+        self.gen_counter = Mutable(torch.tensor(0.0, device=device))
+
+        if optimizer == "adam":
+            self.exp_avg = Mutable(torch.zeros_like(self.center))
+            self.exp_avg_sq = Mutable(torch.zeros_like(self.center))
+            self.beta1 = Parameter(0.9, device=device)
+            self.beta2 = Parameter(0.999, device=device)
+
+    def step(self):
+        """
+        The main step of the ASEBO algorithm.
+
+        This function first computes the subspace spanned by recent gradient
+        estimates and then samples perturbations from a covariance that blends this
+        subspace with an isotropic component. It then estimates the gradient from
+        the sampled fitness values and updates the center and standard deviation of the
+        search distribution.
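+
+        Concretely, the sampling covariance is
+        $C = \\sigma (\\alpha / d) I + ((1 - \\alpha) / (\\lambda / 2)) U U^T$, where $U$
+        holds the top right-singular vectors of the archive of recent gradient
+        estimates, $d$ is the dimension, $\\lambda$ is the population size, and
+        $\\alpha$ is adapted each generation from the ratio of the estimated
+        gradient's norm outside versus inside the subspace.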
+ """ + device = self.center.device + + X = self.grad_subspace + X = X - torch.mean(X, dim=0) + U, S, Vt = torch.svd(X, some=True) + + max_abs_cols = torch.argmax(torch.abs(U), dim=0) + signs = torch.sign(U[max_abs_cols, :]) + U = U * signs + Vt = Vt * signs + + U = Vt[: int(self.pop_size / 2)] + UUT = torch.matmul(U.T, U) + U_ort = Vt[int(self.pop_size / 2) :] + UUT_ort = torch.matmul(U_ort.T, U_ort) + + UUT = torch.where(self.gen_counter > self.subspace_dims, UUT, torch.zeros(self.dim, self.dim, device=device)) + + cov = ( + self.sigma * (self.alpha / self.dim) * torch.eye(self.dim, device=device) + + ((1 - self.alpha) / int(self.pop_size / 2)) * UUT + ) + chol = torch.linalg.cholesky(cov) + noise = torch.randn(self.dim, int(self.pop_size / 2), device=device) + + z_plus = torch.swapaxes(chol @ noise, 0, 1) + z_plus = z_plus / torch.linalg.norm(z_plus, dim=-1)[:, None] + z = torch.cat([z_plus, -1.0 * z_plus]) + + population = self.center + z + + self.gen_counter = self.gen_counter + 1 + + fitness = self.evaluate(population) + + noise = (population - self.center) / self.sigma + noise_1 = noise[: int(self.pop_size / 2)] + fit_1 = fitness[: int(self.pop_size / 2)] + fit_2 = fitness[int(self.pop_size / 2) :] + fit_diff_noise = noise_1.T @ (fit_1 - fit_2) + + theta_grad = 1.0 / 2.0 * fit_diff_noise + alpha = torch.linalg.norm(theta_grad @ UUT_ort) / torch.linalg.norm(theta_grad @ self.UUT) + + alpha = torch.where(self.gen_counter > self.subspace_dims, alpha, 1.0) + + self.grad_subspace = torch.cat([self.grad_subspace, theta_grad[None, :]])[1:, :] + theta_grad /= torch.linalg.norm(theta_grad) / self.dim + 1e-8 + + if self.optimizer is None: + center = self.center - self.lr * theta_grad + else: + center, self.exp_avg, self.exp_avg_sq = adam_single_tensor( + self.center, + theta_grad, + self.exp_avg, + self.exp_avg_sq, + self.beta1, + self.beta2, + self.lr, + ) + sigma = self.sigma * self.sigma_decay + sigma = torch.maximum(sigma, self.sigma_limit) + + self.center = center + self.sigma = sigma + self.alpha = alpha diff --git a/src/evox/algorithms/es_variants/cma_es.py b/src/evox/algorithms/es_variants/cma_es.py index f5b8a3ef9..67aca3560 100644 --- a/src/evox/algorithms/es_variants/cma_es.py +++ b/src/evox/algorithms/es_variants/cma_es.py @@ -25,6 +25,7 @@ def __init__( :param pop_size: The size of the population with the notation $\\lambda$. :param mean_init: The initial mean of the population. Must be a 1D tensor. :param sigma: The standard deviation of the noise. + :param weights: The recombination weights of the population. Defaults to None. :param device: The device to use for the tensors. Defaults to None. """ super().__init__() diff --git a/src/evox/algorithms/es_variants/des.py b/src/evox/algorithms/es_variants/des.py new file mode 100644 index 000000000..4afda1ece --- /dev/null +++ b/src/evox/algorithms/es_variants/des.py @@ -0,0 +1,74 @@ +import torch +import torch.nn.functional as F + +from ...core import Algorithm, Mutable, Parameter, jit_class + + +@jit_class +class DES(Algorithm): + """The implementation of the DES algorithm. + + Reference: + Discovering Evolution Strategies via Meta-Black-Box Optimization + (https://arxiv.org/abs/2211.11260) + + This code has been inspired by or utilizes the algorithmic implementation from evosax. 
+    More information about evosax can be found at the following URL:
+    GitHub Link: https://github.com/RobertTLange/evosax
+    """
+    def __init__(
+        self,
+        pop_size: int,
+        center_init: torch.Tensor,
+        temperature: float = 12.5,
+        sigma_init: float = 0.1,
+        device: torch.device | None = None,
+    ):
+        """Initialize the DES algorithm with the given parameters.
+
+        :param pop_size: The size of the population.
+        :param center_init: The initial center of the population. Must be a 1D tensor.
+        :param temperature: The temperature parameter for the softmax. Defaults to 12.5.
+        :param sigma_init: The initial standard deviation of the noise. Defaults to 0.1.
+        :param device: The device to use for the tensors. Defaults to None.
+        """
+        super().__init__()
+        assert pop_size > 1
+        dim = center_init.shape[0]
+        # set hyperparameters
+        self.temperature = Parameter(temperature, device=device)
+        self.sigma_init = Parameter(sigma_init, device=device)
+        self.lrate_mean = Parameter(1.0, device=device)
+        self.lrate_sigma = Parameter(0.1, device=device)
+        # set value
+        ranks = torch.arange(pop_size, device=device) / (pop_size - 1) - 0.5
+        self.dim = dim
+        self.ranks = ranks
+        self.pop_size = pop_size
+        # setup
+        center_init = center_init.to(device=device)
+        self.center = Mutable(center_init)
+        self.sigma = Mutable(sigma_init * torch.ones(self.dim, device=device))
+
+    def step(self):
+        """Step the DES algorithm by sampling the population, evaluating the fitness, and updating the center."""
+        device = self.center.device
+
+        noise = torch.randn(self.pop_size, self.dim, device=device)
+        population = self.center + noise * self.sigma
+
+        fitness = self.evaluate(population)
+
+        population = population[fitness.argsort()]
+
+        weight = F.softmax(-20 * F.sigmoid(self.temperature * self.ranks), dim=0)
+        weight = torch.tile(weight[:, None], (1, self.dim))
+
+        weight_mean = (weight * population).sum(dim=0)
+        weight_sigma = torch.sqrt((weight * (population - self.center) ** 2).sum(dim=0) + 1e-6)
+
+        center = self.center + self.lrate_mean * (weight_mean - self.center)
+        sigma = self.sigma + self.lrate_sigma * (weight_sigma - self.sigma)
+
+        self.center = center
+        self.sigma = sigma
diff --git a/src/evox/algorithms/es_variants/esmc.py b/src/evox/algorithms/es_variants/esmc.py
new file mode 100644
index 000000000..581986f37
--- /dev/null
+++ b/src/evox/algorithms/es_variants/esmc.py
@@ -0,0 +1,108 @@
+from typing import Literal
+
+import torch
+
+from ...core import Algorithm, Mutable, Parameter, jit_class
+from .adam_step import adam_single_tensor
+
+
+@jit_class
+class ESMC(Algorithm):
+    """The implementation of the ESMC algorithm.
+
+    Reference:
+    Learn2Hop: Learned Optimization on Rough Landscapes
+    (https://proceedings.mlr.press/v139/merchant21a.html)
+
+    This code has been inspired by or utilizes the algorithmic implementation from evosax.
+    More information about evosax can be found at the following URL:
+    GitHub Link: https://github.com/RobertTLange/evosax
+    """
+    def __init__(
+        self,
+        pop_size: int,
+        center_init: torch.Tensor,
+        optimizer: Literal["adam"] | None = None,
+        sigma_decay: float = 1.0,
+        sigma_limit: float = 0.01,
+        lr: float = 0.05,
+        sigma: float = 0.03,
+        device: torch.device | None = None,
+    ):
+        """Initialize the ESMC algorithm with the given parameters.
+
+        :param pop_size: The size of the population.
+        :param center_init: The initial center of the population. Must be a 1D tensor.
+        :param sigma: The standard deviation of the noise. Defaults to 0.03.
+        :param lr: The learning rate for the optimizer.
Defaults to 0.05. + :param sigma_decay: The decay factor for the standard deviation. Defaults to 1.0. + :param sigma_limit: The minimum value for the standard deviation. Defaults to 0.01. + :param optimizer: The optimizer to use. Defaults to None. Currently, only "adam" or None is supported. + :param device: The device to use for the tensors. Defaults to None. + """ + super().__init__() + assert pop_size > 1 + dim = center_init.shape[0] + # set hyperparameters + self.lr = Parameter(lr, device=device) + self.sigma_decay = Parameter(sigma_decay, device=device) + self.sigma_limit = Parameter(sigma_limit, device=device) + # set value + self.dim = dim + self.pop_size = pop_size + self.optimizer = optimizer + # setup + center_init = center_init.to(device=device) + self.center = Mutable(center_init) + self.sigma = Mutable(torch.ones(self.dim, device=device) * sigma) + + if optimizer == "adam": + self.exp_avg = Mutable(torch.zeros_like(self.center)) + self.exp_avg_sq = Mutable(torch.zeros_like(self.center)) + self.beta1 = Parameter(0.9, device=device) + self.beta2 = Parameter(0.999, device=device) + + def step(self): + """One iteration of the ESMC algorithm. + + This function will sample a population, evaluate their fitness, and then + update the center and standard deviation of the algorithm using the + sampled population. + """ + device = self.center.device + + z_plus = torch.randn(int(self.pop_size / 2), self.dim, device=device) + z = torch.cat([torch.zeros(1, self.dim, device=device), z_plus, -1.0 * z_plus]) + + population = self.center + z * self.sigma.reshape(1, self.dim) + + fitness = self.evaluate(population) + + noise = (population - self.center) / self.sigma + bline_fitness = fitness[0] + noise = noise[1:] + fitness = fitness[1:] + noise_1 = noise[: int((self.pop_size - 1) / 2)] + fit_1 = fitness[: int((self.pop_size - 1) / 2)] + fit_2 = fitness[int((self.pop_size - 1) / 2) :] + fit_diff = torch.minimum(fit_1, bline_fitness) - torch.minimum(fit_2, bline_fitness) + fit_diff_noise = noise_1.T @ fit_diff + + theta_grad = 1.0 / int((self.pop_size - 1) / 2) * fit_diff_noise + + if self.optimizer is None: + center = self.center - self.lr * theta_grad + else: + center, self.exp_avg, self.exp_avg_sq = adam_single_tensor( + self.center, + theta_grad, + self.exp_avg, + self.exp_avg_sq, + self.beta1, + self.beta2, + self.lr, + ) + self.center = center + + sigma = torch.maximum(self.sigma * self.sigma_decay, self.sigma_limit) + self.sigma = sigma diff --git a/src/evox/algorithms/es_variants/guided_es.py b/src/evox/algorithms/es_variants/guided_es.py new file mode 100644 index 000000000..6bb02e427 --- /dev/null +++ b/src/evox/algorithms/es_variants/guided_es.py @@ -0,0 +1,121 @@ +from typing import Literal + +import torch + +from ...core import Algorithm, Mutable, Parameter, jit_class +from .adam_step import adam_single_tensor + + +@jit_class +class GuidedES(Algorithm): + """The implementation of the Guided-ES algorithm. + + Reference: + Guided evolutionary strategies: Augmenting random search with surrogate gradients + (https://arxiv.org/abs/1806.10230) + + This code has been inspired by or utilizes the algorithmic implementation from evosax. 
+ More information about evosax can be found at the following URL: + GitHub Link: https://github.com/RobertTLange/evosax + """ + def __init__( + self, + pop_size: int, + center_init: torch.Tensor, + subspace_dims: int | None = None, + optimizer: Literal["adam"] | None = None, + sigma: float = 0.03, + lr: float = 60, + sigma_decay: float = 1.0, + sigma_limit: float = 0.01, + device: torch.device | None = None, + ): + """Initialize the Guided-ES algorithm with the given parameters. + + :param pop_size: The size of the population. + :param center_init: The initial center of the population. Must be a 1D tensor. + :param optimizer: The optimizer to use. Defaults to None. Currently, only "adam" or None is supported. + :param lr: The learning rate for the optimizer. Defaults to 0.05. + :param sigma: The standard deviation of the noise. Defaults to 0.03. + :param sigma_decay: The decay factor for the standard deviation. Defaults to 1.0. + :param sigma_limit: The minimum value for the standard deviation. Defaults to 0.01. + :param subspace_dims: The dimension of the subspace. Defaults to None. + :param device: The device to use for the tensors. Defaults to None. + """ + super().__init__() + assert pop_size > 1 and pop_size % 2 == 0 + dim = center_init.shape[0] + if subspace_dims is None: + subspace_dims = dim + # set hyperparameters + self.beta = Parameter(1.0, device=device) + self.lr = Parameter(lr, device=device) + self.sigma_decay = Parameter(sigma_decay, device=device) + self.sigma_limit = Parameter(sigma_limit, device=device) + # set value + self.dim = dim + self.pop_size = pop_size + self.optimizer = optimizer + self.subspace_dims = subspace_dims + # setup + center_init = center_init.to(device=device) + self.center = Mutable(center_init) + self.alpha = Mutable(torch.tensor(0.5, device=device)) + self.sigma = Mutable(torch.tensor(sigma, device=device)) + self.grad_subspace = Mutable(torch.randn(subspace_dims, dim, device=device)) + + if optimizer == "adam": + self.exp_avg = Mutable(torch.zeros_like(self.center)) + self.exp_avg_sq = Mutable(torch.zeros_like(self.center)) + self.beta1 = Parameter(0.9, device=device) + self.beta2 = Parameter(0.999, device=device) + + def step(self): + """Run one step of the Guided-ES algorithm. + + The function will sample a population, evaluate their fitness, and then + update the center and standard deviation of the algorithm using the + sampled population. 
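+
+        Perturbations are drawn as $\\epsilon = a \\epsilon_{full} + c Q \\epsilon_{sub}$
+        with $a = \\sigma \\sqrt{\\alpha / d}$ and $c = \\sigma \\sqrt{(1 - \\alpha) / k}$,
+        where $Q$ is an orthonormal basis (from a QR decomposition) of the
+        $k$-dimensional archive of past gradients, so the search mixes isotropic noise
+        with the surrogate-gradient subspace; mirrored pairs $\\pm\\epsilon$ give the
+        antithetic gradient estimate.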
+ """ + device = self.center.device + + a = self.sigma * torch.sqrt(self.alpha / self.dim) + c = self.sigma * torch.sqrt((1.0 - self.alpha) / self.subspace_dims) + eps_full = torch.randn(self.dim, int(self.pop_size // 2), device=device) + + eps_subspace = torch.randn(self.subspace_dims, int(self.pop_size // 2), device=device) + Q, _ = torch.linalg.qr(self.grad_subspace) + + z_plus = a * eps_full + c * (Q @ eps_subspace) + z_plus = torch.swapaxes(z_plus, 0, 1) + z = torch.cat([z_plus, -1.0 * z_plus]) + population = self.center + z + + fitness = self.evaluate(population) + + noise = z / self.sigma + noise_1 = noise[: int(self.pop_size / 2)] + fit_1 = fitness[: int(self.pop_size / 2)] + fit_2 = fitness[int(self.pop_size / 2) :] + fit_diff = fit_1 - fit_2 + fit_diff_noise = noise_1.T @ fit_diff + theta_grad = (self.beta / self.pop_size) * fit_diff_noise + + self.grad_subspace = torch.cat([self.grad_subspace, theta_grad[None, :]])[1:, :] + + if self.optimizer is None: + center = self.center - self.lr * theta_grad + else: + center, self.exp_avg, self.exp_avg_sq = adam_single_tensor( + self.center, + theta_grad, + self.exp_avg, + self.exp_avg_sq, + self.beta1, + self.beta2, + self.lr, + ) + self.center = center + + sigma = torch.maximum(self.sigma_decay * self.sigma, self.sigma_limit) + self.sigma = sigma diff --git a/src/evox/algorithms/es_variants/nes.py b/src/evox/algorithms/es_variants/nes.py new file mode 100644 index 000000000..b8285ee0c --- /dev/null +++ b/src/evox/algorithms/es_variants/nes.py @@ -0,0 +1,206 @@ +import math + +import torch + +from ...core import Algorithm, Mutable, Parameter, jit_class + + +@jit_class +class XNES(Algorithm): + """The implementation of the xNES algorithm. + + Reference: + Exponential Natural Evolution Strategies + (https://dl.acm.org/doi/abs/10.1145/1830483.1830557) + """ + def __init__( + self, + init_mean: torch.Tensor, + init_covar: torch.Tensor, + pop_size: int | None = None, + recombination_weights: torch.Tensor | None = None, + learning_rate_mean: float | None = None, + learning_rate_var: float | None = None, + learning_rate_B: float | None = None, + covar_as_cholesky: bool = False, + device: torch.device | None = None, + ): + """Initialize the xNES algorithm with the given parameters. + + :param init_mean: The initial mean vector of the population. Must be a 1D tensor. + :param init_covar: The initial covariance matrix of the population. Must be a 2D tensor. + :param pop_size: The size of the population. Defaults to None. + :param recombination_weights: The recombination weights of the population. Defaults to None. + :param learning_rate_mean: The learning rate for the mean vector. Defaults to None. + :param learning_rate_var: The learning rate for the variance vector. Defaults to None. + :param learning_rate_B: The learning rate for the B matrix. Defaults to None. + :param covar_as_cholesky: Whether to use the covariance matrix as a Cholesky factorization result. Defaults to False. + :param device: The device to use for the tensors. Defaults to None. 
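+
+        Example (illustrative values; the zero mean and identity covariance are
+        arbitrary starting points, and ``evaluate`` is supplied by the surrounding
+        workflow):
+
+        >>> import torch
+        >>> from evox.algorithms import XNES
+        >>> algo = XNES(init_mean=torch.zeros(10), init_covar=torch.eye(10), pop_size=20)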
+        """
+        super().__init__()
+        dim = init_mean.shape[0]
+        if pop_size is None:
+            pop_size = 4 + math.floor(3 * math.log(dim))
+        assert pop_size > 0
+
+        if learning_rate_mean is None:
+            learning_rate_mean = 1
+        if learning_rate_var is None:
+            learning_rate_var = (9 + 3 * math.log(dim)) / 5 / math.pow(dim, 1.5)
+        if learning_rate_B is None:
+            learning_rate_B = learning_rate_var
+        assert learning_rate_mean > 0 and learning_rate_var > 0 and learning_rate_B > 0
+
+        if not covar_as_cholesky:
+            init_covar = torch.linalg.cholesky(init_covar)
+
+        if recombination_weights is None:
+            recombination_weights = torch.arange(1, pop_size + 1)
+            recombination_weights = torch.clip(math.log(pop_size / 2 + 1) - torch.log(recombination_weights), 0)
+            recombination_weights = recombination_weights / torch.sum(recombination_weights) - 1 / pop_size
+        assert (
+            recombination_weights[1:] <= recombination_weights[:-1]
+        ).all(), "recombination_weights must be in descending order"
+
+        # set hyperparameters
+        self.learning_rate_mean = Parameter(learning_rate_mean, device=device)
+        self.learning_rate_var = Parameter(learning_rate_var, device=device)
+        self.learning_rate_B = Parameter(learning_rate_B, device=device)
+        # set value
+        recombination_weights = recombination_weights.to(device=device)
+        self.dim = dim
+        self.pop_size = pop_size
+        self.recombination_weights = recombination_weights
+        # setup
+        init_mean = init_mean.to(device=device)
+        init_covar = init_covar.to(device=device)
+        sigma = torch.pow(torch.prod(torch.diag(init_covar)), 1 / self.dim)
+        self.sigma = Mutable(sigma)
+        self.mean = Mutable(init_mean)
+        self.B = Mutable(init_covar / sigma)
+
+    def step(self):
+        """Run one step of the xNES algorithm.
+
+        The function will sample a population, evaluate their fitness, and then
+        update the mean and covariance of the algorithm using the sampled
+        population.
+        """
+        device = self.mean.device
+
+        noise = torch.randn(self.pop_size, self.dim, device=device)
+        population = self.mean + self.sigma * (noise @ self.B.T)
+
+        fitness = self.evaluate(population)
+
+        order = torch.argsort(fitness)
+        fitness, noise = fitness[order], noise[order]
+
+        weights = self.recombination_weights
+
+        Ind = torch.eye(self.dim, device=device)
+
+        grad_delta = torch.sum(weights[:, None] * noise, dim=0)
+        grad_M = (weights * noise.T) @ noise - torch.sum(weights) * Ind
+        grad_sigma = torch.trace(grad_M) / self.dim
+        grad_B = grad_M - grad_sigma * Ind
+
+        mean = self.mean + self.learning_rate_mean * self.sigma * self.B @ grad_delta
+        sigma = self.sigma * torch.exp(self.learning_rate_var / 2 * grad_sigma)
+        B = self.B @ torch.linalg.matrix_exp(self.learning_rate_B / 2 * grad_B)
+
+        self.sigma = sigma
+        self.mean = mean
+        self.B = B
+
+
+@jit_class
+class SeparableNES(Algorithm):
+    """The implementation of the Separable NES algorithm.
+
+    Reference:
+    Natural Evolution Strategies
+    (https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf)
+    """
+    def __init__(
+        self,
+        init_mean: torch.Tensor,
+        init_std: torch.Tensor,
+        pop_size: int | None = None,
+        recombination_weights: torch.Tensor | None = None,
+        learning_rate_mean: float | None = None,
+        learning_rate_var: float | None = None,
+        device: torch.device | None = None,
+    ):
+        """Initialize the Separable NES algorithm with the given parameters.
+
+        :param init_mean: The initial mean vector of the population. Must be a 1D tensor.
+        :param init_std: The initial standard deviation for each dimension. Must be a 1D tensor.
+        :param pop_size: The size of the population. Defaults to None.
+        :param recombination_weights: The recombination weights of the population. Defaults to None.
+        :param learning_rate_mean: The learning rate for the mean vector. Defaults to None.
+        :param learning_rate_var: The learning rate for the variance vector. Defaults to None.
+        :param device: The device to use for the tensors. Defaults to None.
+        """
+        super().__init__()
+        dim = init_mean.shape[0]
+        assert init_std.shape == (dim,)
+
+        if pop_size is None:
+            pop_size = 4 + math.floor(3 * math.log(dim))
+        assert pop_size > 0
+
+        if learning_rate_mean is None:
+            learning_rate_mean = 1
+        if learning_rate_var is None:
+            learning_rate_var = (3 + math.log(dim)) / 5 / math.sqrt(dim)
+        assert learning_rate_mean > 0 and learning_rate_var > 0
+
+        if recombination_weights is None:
+            recombination_weights = torch.arange(1, pop_size + 1)
+            recombination_weights = torch.clip(math.log(pop_size / 2 + 1) - torch.log(recombination_weights), 0)
+            recombination_weights = recombination_weights / torch.sum(recombination_weights) - 1 / pop_size
+        assert recombination_weights.shape == (pop_size,)
+
+        # set hyperparameters
+        self.learning_rate_mean = Parameter(learning_rate_mean, device=device)
+        self.learning_rate_var = Parameter(learning_rate_var, device=device)
+        # set value
+        recombination_weights = recombination_weights.to(device=device)
+        self.dim = dim
+        self.pop_size = pop_size
+        self.weight = recombination_weights
+        # setup
+        init_std = init_std.to(device=device)
+        init_mean = init_mean.to(device=device)
+        self.mean = Mutable(init_mean)
+        self.sigma = Mutable(init_std)
+
+    def step(self):
+        """Run one step of the Separable NES algorithm.
+
+        The function will sample a population, evaluate their fitness, and then
+        update the mean and standard deviation of the algorithm using the sampled
+        population.
+        """
+        device = self.mean.device
+
+        zero_mean_pop = torch.randn(self.pop_size, self.dim, device=device)
+        population = self.mean + zero_mean_pop * self.sigma
+
+        fitness = self.evaluate(population)
+
+        order = torch.argsort(fitness)
+        fitness, population, zero_mean_pop = fitness[order], population[order], zero_mean_pop[order]
+
+        weight = torch.tile(self.weight[:, None], (1, self.dim))
+
+        grad_mean = torch.sum(weight * zero_mean_pop, dim=0)
+        grad_sigma = torch.sum(weight * (zero_mean_pop * zero_mean_pop - 1), dim=0)
+
+        mean = self.mean + self.learning_rate_mean * self.sigma * grad_mean
+        sigma = self.sigma * torch.exp(self.learning_rate_var / 2 * grad_sigma)
+
+        self.mean = mean
+        self.sigma = sigma
diff --git a/src/evox/algorithms/es_variants/noise_reuse_es.py b/src/evox/algorithms/es_variants/noise_reuse_es.py
new file mode 100644
index 000000000..ab6b8ad78
--- /dev/null
+++ b/src/evox/algorithms/es_variants/noise_reuse_es.py
@@ -0,0 +1,116 @@
+from typing import Literal
+
+import torch
+
+from ...core import Algorithm, Mutable, Parameter, jit_class
+from .adam_step import adam_single_tensor
+
+
+@jit_class
+class NoiseReuseES(Algorithm):
+    """The implementation of the Noise-Reuse-ES algorithm.
+
+    Reference:
+    Noise-Reuse in Online Evolution Strategies
+    (https://arxiv.org/pdf/2304.12180.pdf)
+
+    This code has been inspired by or utilizes the algorithmic implementation from evosax.
+    More information about evosax can be found at the following URL:
+    GitHub Link: https://github.com/RobertTLange/evosax
+    """
+    def __init__(
+        self,
+        pop_size: int,
+        center_init: torch.Tensor,
+        optimizer: Literal["adam"] | None = None,
+        lr: float = 0.05,
+        sigma: float = 0.03,
+        T: int = 100,  # inner problem length
+        K: int = 10,
+        sigma_decay: float = 1.0,
+        sigma_limit: float = 0.01,
+        device: torch.device | None = None,
+    ):
+        """Initialize the Noise-Reuse-ES algorithm with the given parameters.
+
+        :param pop_size: The size of the population.
+        :param center_init: The initial center of the population. Must be a 1D tensor.
+        :param optimizer: The optimizer to use. Defaults to None. Currently, only "adam" or None is supported.
+        :param lr: The learning rate for the optimizer. Defaults to 0.05.
+        :param sigma: The standard deviation of the noise. Defaults to 0.03.
+        :param sigma_decay: The decay factor for the standard deviation. Defaults to 1.0.
+        :param sigma_limit: The minimum value for the standard deviation. Defaults to 0.01.
+        :param T: The inner problem length. Defaults to 100.
+        :param K: The number of inner problems. Defaults to 10.
+        :param device: The device to use for the tensors. Defaults to None.
+        """
+        super().__init__()
+        assert pop_size > 1
+        dim = center_init.shape[0]
+        # set hyperparameters
+        self.lr = Parameter(lr, device=device)
+        self.T = Parameter(T, device=device)
+        self.K = Parameter(K, device=device)
+        self.sigma_decay = Parameter(sigma_decay, device=device)
+        self.sigma_limit = Parameter(sigma_limit, device=device)
+        # set value
+        self.dim = dim
+        self.pop_size = pop_size
+        self.optimizer = optimizer
+        # setup
+        center_init = center_init.to(device=device)
+        self.center = Mutable(center_init)
+        self.sigma = Mutable(torch.tensor(sigma, device=device))
+        self.inner_step_counter = Mutable(torch.tensor(0.0, device=device))
+        self.unroll_pert = Mutable(torch.zeros(pop_size, self.dim, device=device))
+
+        if optimizer == "adam":
+            self.exp_avg = Mutable(torch.zeros_like(self.center))
+            self.exp_avg_sq = Mutable(torch.zeros_like(self.center))
+            self.beta1 = Parameter(0.9, device=device)
+            self.beta2 = Parameter(0.999, device=device)
+
+    def step(self):
+        """
+        Take a single step of the NoiseReuseES algorithm.
+
+        This function follows the algorithm described in the reference paper.
+        It first generates a set of perturbations for the current population,
+        reusing the stored perturbations while an inner-problem unroll is in progress.
+        Then, it evaluates the fitness of the population with the perturbations.
+        Afterwards, it estimates the gradient of the parameters from the
+        perturbations and the fitness.
+        Finally, it updates the parameters using the gradient and the
+        learning rate.
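+
+        The gradient estimate keeps the same perturbations for an entire inner
+        unroll of $T$ steps, evaluated $K$ steps at a time:
+        $g = \\frac{1}{N} \\sum_{i=1}^{N} \\epsilon_i f_i / \\sigma^2$, where $N$ is the
+        population size; fresh mirrored perturbations are drawn only when the inner
+        step counter wraps around at $T$.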
+ """ + device = self.center.device + + position_perturbations = torch.randn(self.pop_size // 2, self.dim, device=device) * self.sigma + negative_perturbations = -position_perturbations + perturbations = torch.cat([position_perturbations, negative_perturbations], dim=0) + unroll_pert = torch.where(self.inner_step_counter == 0, perturbations, self.unroll_pert) + + population = self.center + unroll_pert + + fitness = self.evaluate(population) + + theta_grad = torch.mean(unroll_pert * fitness.reshape(-1, 1) / (self.sigma**2), dim=0) + + if self.optimizer is None: + center = self.center - self.lr * theta_grad + else: + center, self.exp_avg, self.exp_avg_sq = adam_single_tensor( + self.center, + theta_grad, + self.exp_avg, + self.exp_avg_sq, + self.beta1, + self.beta2, + self.lr, + ) + self.center = center + + inner_step_counter = torch.where(self.inner_step_counter + self.K >= self.T, 0, self.inner_step_counter + self.K) + self.inner_step_counter = inner_step_counter + + sigma = torch.maximum(self.sigma_decay * self.sigma, self.sigma_limit) + self.sigma = sigma diff --git a/src/evox/algorithms/es_variants/open_es.py b/src/evox/algorithms/es_variants/open_es.py index c3db77738..16c3e63d2 100644 --- a/src/evox/algorithms/es_variants/open_es.py +++ b/src/evox/algorithms/es_variants/open_es.py @@ -21,7 +21,7 @@ def __init__( mirrored_sampling: bool = True, device: torch.device | None = None, ): - """Initialize the PSO algorithm with the given parameters. + """Initialize the OpenES algorithm with the given parameters. :param pop_size: The size of the population. :param center_init: The initial center of the population. Must be a 1D tensor. @@ -58,6 +58,7 @@ def __init__( self.beta2 = Parameter(0.999, device=device) def step(self): + """Step the OpenES algorithm by evaluating the fitness of the current population and updating the center.""" device = self.center.device if self.mirrored_sampling: noise = torch.randn(self.pop_size // 2, self.dim, device=device) diff --git a/src/evox/algorithms/es_variants/persistent_es.py b/src/evox/algorithms/es_variants/persistent_es.py new file mode 100644 index 000000000..10f09777b --- /dev/null +++ b/src/evox/algorithms/es_variants/persistent_es.py @@ -0,0 +1,111 @@ +from typing import Literal + +import torch + +from ...core import Algorithm, Mutable, Parameter, jit_class +from .adam_step import adam_single_tensor + + +@jit_class +class PersistentES(Algorithm): + """The implementation of the Persistent ES algorithm. + + Reference: + Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies + (http://proceedings.mlr.press/v139/vicol21a.html) + + This code has been inspired by or utilizes the algorithmic implementation from evosax. + More information about evosax can be found at the following URL: + GitHub Link: https://github.com/RobertTLange/evosax + """ + def __init__( + self, + pop_size: int, + center_init: torch.Tensor, + optimizer: Literal["adam"] | None = None, + lr: float = 0.05, + sigma: float = 0.03, + T: int = 100, + K: int = 10, + sigma_decay: float = 1.0, + sigma_limit: float = 0.01, + device: torch.device | None = None, + ): + """Initialize the Persistent-ES algorithm with the given parameters. + + :param pop_size: The size of the population. + :param center_init: The initial center of the population. Must be a 1D tensor. + :param optimizer: The optimizer to use. Defaults to None. Currently, only "adam" or None is supported. + :param lr: The learning rate for the optimizer. Defaults to 0.05. 
+        :param sigma: The standard deviation of the noise. Defaults to 0.03.
+        :param T: The inner problem length. Defaults to 100.
+        :param K: The number of inner steps per unroll (truncation length). Defaults to 10.
+        :param sigma_decay: The decay factor for the standard deviation. Defaults to 1.0.
+        :param sigma_limit: The minimum value for the standard deviation. Defaults to 0.01.
+        :param device: The device to use for the tensors. Defaults to None.
+        """
+        super().__init__()
+        assert pop_size > 1 and pop_size % 2 == 0  # population size must be even for antithetic sampling
+        dim = center_init.shape[0]
+        # set hyperparameters
+        self.lr = Parameter(lr, device=device)
+        self.T = Parameter(T, device=device)
+        self.K = Parameter(K, device=device)
+        self.sigma_decay = Parameter(sigma_decay, device=device)
+        self.sigma_limit = Parameter(sigma_limit, device=device)
+        # set value
+        self.dim = dim
+        self.pop_size = pop_size
+        self.optimizer = optimizer
+        # setup
+        center_init = center_init.to(device=device)
+        self.sigma = Mutable(torch.tensor(sigma, device=device))
+        self.center = Mutable(center_init)
+        self.inner_step_counter = Mutable(torch.tensor(0.0, device=device))
+        self.pert_accum = Mutable(torch.zeros(pop_size, dim, device=device))
+
+        if optimizer == "adam":
+            self.exp_avg = Mutable(torch.zeros_like(self.center))
+            self.exp_avg_sq = Mutable(torch.zeros_like(self.center))
+            self.beta1 = Parameter(0.9, device=device)
+            self.beta2 = Parameter(0.999, device=device)
+
+    def step(self):
+        """Take a single step of the Persistent-ES algorithm: sample antithetic
+        perturbations, accumulate them across the unroll, estimate the gradient
+        from the accumulated perturbations and the fitness, and update the center.
+        """
+        device = self.center.device
+
+        pos_perts = torch.randn(self.pop_size // 2, self.dim, device=device) * self.sigma
+        neg_perts = -pos_perts
+        perts = torch.cat([pos_perts, neg_perts], dim=0)
+        pert_accum = self.pert_accum + perts
+        population = self.center + perts
+
+        fitness = self.evaluate(population)
+
+        theta_grad = torch.mean(pert_accum * fitness.reshape(-1, 1) / (self.sigma**2), dim=0)
+
+        if self.optimizer is None:
+            center = self.center - self.lr * theta_grad
+        else:
+            center, self.exp_avg, self.exp_avg_sq = adam_single_tensor(
+                self.center,
+                theta_grad,
+                self.exp_avg,
+                self.exp_avg_sq,
+                self.beta1,
+                self.beta2,
+                self.lr,
+            )
+        self.center = center
+
+        # advance the inner-step counter and reset both the counter and the
+        # perturbation accumulator once a full inner problem of length T is done
+        inner_step_counter = self.inner_step_counter + self.K
+        reset = inner_step_counter >= self.T
+        inner_step_counter = torch.where(reset, 0, inner_step_counter)
+        pert_accum = torch.where(reset, torch.zeros(self.pop_size, self.dim, device=device), pert_accum)
+        self.inner_step_counter = inner_step_counter
+
+        sigma = self.sigma_decay * self.sigma
+        sigma = torch.maximum(sigma, self.sigma_limit)
+
+        self.sigma = sigma
+        self.pert_accum = pert_accum
diff --git a/src/evox/algorithms/es_variants/snes.py b/src/evox/algorithms/es_variants/snes.py
new file mode 100644
index 000000000..380576c1d
--- /dev/null
+++ b/src/evox/algorithms/es_variants/snes.py
@@ -0,0 +1,94 @@
+import math
+from typing import Literal
+
+import torch
+import torch.nn.functional as F
+
+from ...core import Algorithm, Mutable, Parameter, jit_class
+
+
+@jit_class
+class SNES(Algorithm):
+    """The implementation of the SNES algorithm.
+
+    Reference:
+    Natural Evolution Strategies
+    (https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf)
+
+    This code has been inspired by or utilizes the algorithmic implementation from evosax.
+    More information about evosax can be found at the following URL:
+    GitHub Link: https://github.com/RobertTLange/evosax
+    """
+
+    def __init__(
+        self,
+        pop_size: int,
+        center_init: torch.Tensor,
+        sigma: float = 1.0,
+        lrate_mean: float = 1.0,
+        temperature: float = 12.5,
+        weight_type: Literal["recomb", "temp"] = "temp",
+        device: torch.device | None = None,
+    ):
+        """Initialize the SNES algorithm with the given parameters.
+
+        :param pop_size: The size of the population.
+        :param center_init: The initial center of the population. Must be a 1D tensor.
+        :param sigma: The standard deviation of the noise. Defaults to 1.0.
+        :param lrate_mean: The learning rate for the mean. Defaults to 1.0.
+        :param temperature: The temperature of the softmax used to compute the weights. Defaults to 12.5.
+        :param weight_type: The type of weights to use. Defaults to "temp".
+        :param device: The device to use for the tensors. Defaults to None.
+        """
+        super().__init__()
+        assert pop_size > 1
+        dim = center_init.shape[0]
+        # set hyperparameters
+        lrate_sigma = (3 + math.log(dim)) / (5 * math.sqrt(dim))
+        self.lrate_mean = Parameter(lrate_mean, device=device)
+        self.lrate_sigma = Parameter(lrate_sigma, device=device)
+        self.temperature = Parameter(temperature, device=device)
+        # set value
+        self.dim = dim
+        self.pop_size = pop_size
+        # setup
+        center_init = center_init.to(device=device)
+
+        if weight_type == "temp":
+            weights = torch.arange(pop_size, device=device) / (pop_size - 1) - 0.5
+            weights = F.softmax(-20 * F.sigmoid(temperature * weights), dim=0)
+        elif weight_type == "recomb":
+            weights = torch.clip(math.log(pop_size / 2 + 1) - torch.log(torch.arange(1, pop_size + 1, device=device)), 0)
+            weights = weights / torch.sum(weights) - 1 / pop_size
+
+        weights = torch.tile(weights[:, None], (1, self.dim))
+
+        self.weights = Mutable(weights, device=device)
+        self.center = Mutable(center_init)
+        self.sigma = Mutable(sigma * torch.ones(self.dim, device=device))
+
+    def step(self):
+        """Run one step of the SNES algorithm.
+
+        The function samples a population, evaluates its fitness, and then
+        updates the center and the per-dimension standard deviation using the
+        rank-weighted noise of the sampled population.
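+        Concretely, with fitness-sorted noise s_i and per-rank weights w_i, the updates are:
+
+            center <- center + lrate_mean * sigma * sum_i(w_i * s_i)
+            sigma  <- sigma * exp(lrate_sigma / 2 * sum_i(w_i * (s_i**2 - 1)))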
+ """ + device = self.center.device + + noise = torch.randn(self.pop_size, self.dim, device=device) + population = self.center + noise * self.sigma.reshape(1, self.dim) + + fitness = self.evaluate(population) + + order = fitness.argsort() + sorted_noise = noise[order] + grad_mean = (self.weights * sorted_noise).sum(dim=0) + grad_sigma = (self.weights * (sorted_noise**2 - 1)).sum(dim=0) + + center = self.center + self.lrate_mean * self.sigma * grad_mean + sigma = self.sigma * torch.exp(self.lrate_sigma / 2 * grad_sigma) + + self.center = center + self.sigma = sigma diff --git a/src/evox/operators/crossover/__init__.py b/src/evox/operators/crossover/__init__.py index 5b0725a76..ce1ed39e2 100644 --- a/src/evox/operators/crossover/__init__.py +++ b/src/evox/operators/crossover/__init__.py @@ -1,4 +1,17 @@ -__all__ = ["simulated_binary", "simulated_binary_half"] +__all__ = [ + "DE_differential_sum", + "DE_exponential_crossover", + "DE_binary_crossover", + "DE_arithmetic_recombination", + "simulated_binary", + "simulated_binary_half", +] +from .differential_evolution import ( + DE_arithmetic_recombination, + DE_binary_crossover, + DE_differential_sum, + DE_exponential_crossover, +) from .sbx import simulated_binary from .sbx_half import simulated_binary_half diff --git a/src/evox/operators/crossover/differential_evolution.py b/src/evox/operators/crossover/differential_evolution.py new file mode 100644 index 000000000..c7711fda9 --- /dev/null +++ b/src/evox/operators/crossover/differential_evolution.py @@ -0,0 +1,96 @@ +from typing import Tuple + +import torch + +from ...utils import minimum_int + + +def DE_differential_sum( + diff_padding_num: int, num_diff_vectors: torch.Tensor, index: torch.Tensor, population: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes the difference vectors' sum in differential evolution. + + :param diff_padding_num: The number of padding difference vectors. + :param num_diff_vectors: The number of difference vectors used in mutation. + :param index: The index of current individual. + :param population: The population tensor. + + :return: The difference sum and the index of first difference vector. + """ + device = population.device + pop_size = population.size(0) + if num_diff_vectors.ndim == 0: + num_diff_vectors = num_diff_vectors[None].expand(pop_size) + + select_len = num_diff_vectors.unsqueeze(1) * 2 + 1 + rand_indices = torch.randint(0, pop_size, (pop_size, diff_padding_num), device=device) + rand_indices = torch.where(rand_indices == index.unsqueeze(1), torch.tensor(pop_size - 1, device=device), rand_indices) + + pop_permute = population[rand_indices] + mask = torch.tile(torch.arange(diff_padding_num, device=device), (pop_size, 1)) < select_len + pop_permute_padding = torch.where(mask.unsqueeze(2), pop_permute, torch.zeros_like(pop_permute)) + + diff_vectors = pop_permute_padding[:, 1:] + difference_sum = diff_vectors[:, 0::2].sum(dim=1) - diff_vectors[:, 1::2].sum(dim=1) + return difference_sum, rand_indices[:, 0] + + +def DE_binary_crossover(mutation_vector: torch.Tensor, current_vector: torch.Tensor, CR: torch.Tensor): + """ + Performs binary crossover in differential evolution. + + :param mutation_vector: The mutated vector for each individual in the population. + :param current_vector: The current vector for each individual in the population. + :param CR: The crossover probability for each individual. + + :return: The trial vector after crossover for each individual. 
+ """ + device = mutation_vector.device + pop_size, dim = mutation_vector.size() + if CR.ndim == 1: + CR = CR.unsqueeze(1) + mask = torch.randn(pop_size, dim, device=device) < CR + rind = torch.randint(0, dim, (pop_size,), device=device).unsqueeze(1) + jind = torch.arange(dim, device=device).unsqueeze(0) == rind + trial_vector = torch.where(torch.logical_or(mask, jind), mutation_vector, current_vector) + return trial_vector + + +def DE_exponential_crossover(mutation_vector: torch.Tensor, current_vector: torch.Tensor, CR: torch.Tensor): + """ + Performs exponential crossover in differential evolution. + + :param mutation_vector: The mutated vector for each individual in the population. + :param current_vector: The current vector for each individual in the population. + :param CR: The crossover probability for each individual. + + :return: The trial vector after crossover for each individual. + """ + device = mutation_vector.device + pop_size, dim = mutation_vector.size() + nn = torch.randint(0, dim, (pop_size,), device=device) + # Geometric distribution random ll + float_tiny = 1.1754943508222875e-38 + ll = torch.rand(pop_size, device=device).clamp(min=float_tiny) + ll = (ll.log() / (-CR.log1p())).floor().to(dtype=nn.dtype) + mask = torch.arange(dim, device=device).unsqueeze(0) < (minimum_int(ll, dim) - 1).unsqueeze(1) + mask = torch.gather(torch.tile(mask, (1, 2)), 1, nn.unsqueeze(1) + torch.arange(dim, device=device)) + trial_vector = torch.where(mask, mutation_vector, current_vector) + return trial_vector + + +def DE_arithmetic_recombination(mutation_vector: torch.Tensor, current_vector: torch.Tensor, K: torch.Tensor): + """ + Performs arithmetic recombination in differential evolution. + + :param mutation_vector: The mutated vector for each individual in the population. + :param current_vector: The current vector for each individual in the population. + :param K: The coefficient for each individual. + + :return: The trial vector after recombination for each individual. + """ + if K.ndim == 1: + K = K.unsqueeze(1) + trial_vector = current_vector + K * (mutation_vector - current_vector) + return trial_vector diff --git a/src/evox/operators/selection/__init__.py b/src/evox/operators/selection/__init__.py index bb944c3a4..00201b842 100644 --- a/src/evox/operators/selection/__init__.py +++ b/src/evox/operators/selection/__init__.py @@ -5,10 +5,12 @@ "non_dominate_rank", "non_dominated_sort_script", "ref_vec_guided", + "select_rand_pbest", "tournament_selection", "tournament_selection_multifit", ] +from .find_pbest import select_rand_pbest from .non_dominate import ( NonDominatedSort, crowding_distance, diff --git a/src/evox/operators/selection/find_pbest.py b/src/evox/operators/selection/find_pbest.py new file mode 100644 index 000000000..ce100bb9e --- /dev/null +++ b/src/evox/operators/selection/find_pbest.py @@ -0,0 +1,19 @@ +import torch + + +def select_rand_pbest(percent: float, population: torch.Tensor, fitness: torch.Tensor) -> torch.Tensor: + """ + Selects a random personal-best vector from the population for each individual. + + :param percent: The proportion of the population to consider as best. Must be between 0 and 1. + :param population: The population tensor of shape `(pop_size, dim)`. + :param fitness: The fitness tensor of shape `(pop_size,)`. + + :return: A tensor containing the selected personal-best vector for each individual. 
+ """ + device = population.device + pop_size = population.size(0) + top_p_num = max(int(pop_size * percent), 1) + pbest_indices_pool = torch.argsort(fitness)[:top_p_num] + random_indices = torch.randint(0, top_p_num, (pop_size,), device=device) + return population[pbest_indices_pool[random_indices]] diff --git a/src/evox/utils/__init__.py b/src/evox/utils/__init__.py index 6c796b3cd..f1495ae6b 100644 --- a/src/evox/utils/__init__.py +++ b/src/evox/utils/__init__.py @@ -6,6 +6,10 @@ "clip", "maximum", "minimum", + "maximum_float", + "minimum_float", + "maximum_int", + "minimum_int", "TracingWhile", "TracingCond", "TracingSwitch", @@ -16,7 +20,22 @@ ] from .control_flow import TracingCond, TracingSwitch, TracingWhile -from .jit_fix_operator import clamp, clamp_float, clamp_int, clip, lexsort, maximum, minimum, nanmax, nanmin, switch +from .jit_fix_operator import ( + clamp, + clamp_float, + clamp_int, + clip, + lexsort, + maximum, + maximum_float, + maximum_int, + minimum, + minimum_float, + minimum_int, + nanmax, + nanmin, + switch, +) from .parameters_and_vector import ParamsAndVector diff --git a/src/evox/utils/jit_fix_operator.py b/src/evox/utils/jit_fix_operator.py index 7f682aa83..814f71f9b 100644 --- a/src/evox/utils/jit_fix_operator.py +++ b/src/evox/utils/jit_fix_operator.py @@ -132,6 +132,68 @@ def minimum(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return a - diff +def maximum_float(a: torch.Tensor, b: float) -> torch.Tensor: + """ + Element-wise maximum of input tensor `a` and float `b`. + + Notice: This is a fix function for [`torch.maximum`](https://pytorch.org/docs/stable/generated/torch.maximum.html] since it is not supported in JIT operator fusion. + + :param a: The first input tensor. + :param b: The second input float, which is a scalar value. + + :return: The element-wise maximum of `a` and `b`. + """ + diff = torch.relu(b - a) + return a + diff + + +def minimum_float(a: torch.Tensor, b: float) -> torch.Tensor: + """ + Element-wise minimum of input tensor `a` and float `b`. + + Notice: This is a fix function for [`torch.minimum`](https://pytorch.org/docs/stable/generated/torch.minimum.html) + since it is not supported in JIT operator fusion. + + :param a: The first input tensor. + :param b: The second input float, which is a scalar value. + + :return: The element-wise minimum of `a` and `b`. + """ + diff = torch.relu(a - b) + return a - diff + + +def maximum_int(a: torch.Tensor, b: int) -> torch.Tensor: + """ + Element-wise maximum of input tensor `a` and int `b`. + + Notice: This is a fix function for [`torch.maximum`](https://pytorch.org/docs/stable/generated/torch.maximum.html] since it is not supported in JIT operator fusion. + + :param a: The first input tensor. + :param b: The second input int, which is a scalar value. + + :return: The element-wise maximum of `a` and `b`. + """ + diff = torch.relu(b - a) + return a + diff + + +def minimum_int(a: torch.Tensor, b: int) -> torch.Tensor: + """ + Element-wise minimum of input tensor `a` and int `b`. + + Notice: This is a fix function for [`torch.minimum`](https://pytorch.org/docs/stable/generated/torch.minimum.html) + since it is not supported in JIT operator fusion. + + :param a: The first input tensor. + :param b: The second input int, which is a scalar value. + + :return: The element-wise minimum of `a` and `b`. 
+ """ + diff = torch.relu(a - b) + return a - diff + + def lexsort(keys: List[torch.Tensor], dim: int = -1) -> torch.Tensor: """ Perform lexicographical sorting of multiple tensors, considering each tensor as a key. diff --git a/unit_test/algorithms/test_de_variants.py b/unit_test/algorithms/test_de_variants.py index 83511c59e..7b1401549 100644 --- a/unit_test/algorithms/test_de_variants.py +++ b/unit_test/algorithms/test_de_variants.py @@ -1,6 +1,6 @@ import torch -from evox.algorithms import DE, ODE +from evox.algorithms import DE, ODE, SHADE, CoDE, SaDE from .test_base import TestBase @@ -17,6 +17,9 @@ def setUp(self): DE(pop_size, lb, ub, base_vector="best"), ODE(pop_size, lb, ub, base_vector="rand"), ODE(pop_size, lb, ub, base_vector="best"), + SHADE(pop_size, lb, ub), + CoDE(pop_size, lb, ub), + SaDE(pop_size, lb, ub), ] def test_de_variants(self): diff --git a/unit_test/algorithms/test_es_variants.py b/unit_test/algorithms/test_es_variants.py index bc13f16fc..f83cf539a 100644 --- a/unit_test/algorithms/test_es_variants.py +++ b/unit_test/algorithms/test_es_variants.py @@ -1,13 +1,25 @@ import torch -from evox.algorithms import CMAES, OpenES +from evox.algorithms import ( + ARS, + ASEBO, + CMAES, + DES, + ESMC, + SNES, + XNES, + GuidedES, + NoiseReuseES, + OpenES, + PersistentES, + SeparableNES, +) from .test_base import TestBase class TestESVariants(TestBase): def setUp(self): - torch.manual_seed(42) pop_size = 10 dim = 4 lb = -10 * torch.ones(dim) @@ -27,6 +39,84 @@ def setUp(self): noise_stdev=5, optimizer="adam", ), + XNES( + pop_size=pop_size, + init_mean=torch.rand(dim) * (ub - lb) + lb, + init_covar=torch.eye(dim), + ), + SeparableNES( + pop_size=pop_size, + init_mean=torch.rand(dim) * (ub - lb) + lb, + init_std=torch.full((dim,), 1), + ), + DES( + pop_size=pop_size, + center_init=torch.rand(dim) * (ub - lb) + lb, + ), + ESMC( + pop_size=pop_size | 1, + center_init=torch.rand(dim) * (ub - lb) + lb, + ), + ESMC( + pop_size=pop_size | 1, + center_init=torch.rand(dim) * (ub - lb) + lb, + optimizer="adam", + ), + SNES( + pop_size=pop_size, + center_init=torch.rand(dim) * (ub - lb) + lb, + weight_type="recomb", + ), + SNES( + pop_size=pop_size, + center_init=torch.rand(dim) * (ub - lb) + lb, + weight_type="temp", + ), + PersistentES( + pop_size=pop_size, + center_init=torch.rand(dim) * (ub - lb) + lb, + ), + PersistentES( + pop_size=pop_size, + center_init=torch.rand(dim) * (ub - lb) + lb, + optimizer="adam", + ), + GuidedES( + pop_size=pop_size, + center_init=torch.rand(dim) * (ub - lb) + lb, + ), + GuidedES( + pop_size=pop_size, + center_init=torch.rand(dim) * (ub - lb) + lb, + optimizer="adam", + ), + NoiseReuseES( + pop_size=pop_size, + center_init=torch.rand(dim) * (ub - lb) + lb, + ), + NoiseReuseES( + pop_size=pop_size, + center_init=torch.rand(dim) * (ub - lb) + lb, + optimizer="adam", + ), + ARS( + pop_size=pop_size, + center_init=torch.rand(dim) * (ub - lb) + lb, + ), + ARS( + pop_size=pop_size, + center_init=torch.rand(dim) * (ub - lb) + lb, + optimizer="adam", + ), + ASEBO( + pop_size=pop_size, + center_init=torch.rand(dim) * (ub - lb) + lb, + ), + ASEBO( + pop_size=pop_size, + center_init=torch.rand(dim) * (ub - lb) + lb, + optimizer="adam", + ), CMAES( mean_init=torch.rand(dim) * (ub - lb) + lb, sigma=5,