Skip to content
Open

Meoh #49

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 100 additions & 95 deletions llm4ad/method/meoh/meoh.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#
# For inquiries regarding commercial use or licensing, please contact
# http://www.llm4ad.com/contact.html
# --------------------------------------------------------------------------

from __future__ import annotations

Expand All @@ -30,6 +29,7 @@
import time
import traceback
from threading import Thread
from typing import Optional, Literal
import numpy as np

from .population import Population
Expand All @@ -47,9 +47,9 @@ def __init__(self,
llm: LLM,
evaluation: Evaluation,
profiler: ProfilerBase = None,
max_generations: int | None = 10,
max_sample_nums: int | None = 100,
pop_size: int = 20,
max_generations: Optional[int] = 10,
max_sample_nums: Optional[int] = 100,
pop_size: Optional[int] = 5,
selection_num=5,
use_e2_operator: bool = True,
use_m1_operator: bool = True,
Expand All @@ -59,9 +59,9 @@ def __init__(self,
num_objs: int = 2,
*,
resume_mode: bool = False,
initial_sample_num: int | None = None,
initial_sample_nums_max: int = 50,
debug_mode: bool = False,
multi_thread_or_process_eval: str = 'thread',
multi_thread_or_process_eval: Literal['thread', 'process'] = 'thread',
**kwargs):
"""
Args:
Expand All @@ -82,7 +82,8 @@ def __init__(self,
setting this parameter to 'process' will be faster than 'thread'. However, I am not sure whether this holds on all platforms, so the default is set to 'thread'.
Please note that there is one case that cannot utilize multi-core CPU: if you set 'safe_evaluate' argument in 'evaluator' to 'False',
and you set this argument to 'thread'.
**kwargs : some args pass to 'llm4ad.base.SecureEvaluator'. Such as 'fork_proc'.
initial_sample_nums_max : maximum number of samples allowed during population initialization (raised to at least 2 * pop_size).
**kwargs : some args pass to 'llm4ad.base.SecureEvaluator'. Such as 'fork_proc'.
"""
self._template_program_str = evaluation.template_program
self._task_description_str = evaluation.task_description
Expand All @@ -94,29 +95,40 @@ def __init__(self,
self._use_e2_operator = use_e2_operator
self._use_m1_operator = use_m1_operator
self._use_m2_operator = use_m2_operator

# samplers and evaluators
self._num_samplers = num_samplers
self._num_evaluators = num_evaluators
self._resume_mode = resume_mode
self._initial_sample_num = initial_sample_num
self._initial_sample_nums_max = initial_sample_nums_max
self._debug_mode = debug_mode
llm.debug_mode = debug_mode
self._multi_thread_or_process_eval = multi_thread_or_process_eval

# function to be evolved
self._function_to_evolve: Function = TextFunctionProgramConverter.text_to_function(self._template_program_str)
self._function_to_evolve_name: str = self._function_to_evolve.name
self._template_program: Program = TextFunctionProgramConverter.text_to_program(self._template_program_str)

# adjust population size
self._adjust_pop_size()

# population, sampler, and evaluator
self._population = Population(pop_size=self._pop_size)
llm.debug_mode = debug_mode
self._sampler = MEoHSampler(llm, self._template_program_str)
self._evaluator = SecureEvaluator(evaluation, debug_mode=debug_mode, **kwargs)
self._profiler = profiler
if profiler is not None:
self._profiler.record_parameters(llm, evaluation, self) # ZL: Necessary

# statistics
self._tot_sample_nums = 0 if initial_sample_num is None else initial_sample_num
self._tot_sample_nums = 0

# reset _initial_sample_nums_max
self._initial_sample_nums_max = max(
self._initial_sample_nums_max,
2 * self._pop_size
)

# multi-thread executor for evaluation
assert multi_thread_or_process_eval in ['thread', 'process']
Expand All @@ -129,120 +141,122 @@ def __init__(self,
max_workers=num_evaluators
)

# pass parameters to profiler
if profiler is not None:
self._profiler.record_parameters(llm, evaluation, self) # ZL: necessary

def _adjust_pop_size(self):
# adjust population size
if self._max_sample_nums >= 10000:
if self._pop_size is None:
self._pop_size = 40
elif abs(self._pop_size - 40) > 20:
print(f'Warning: population size {self._pop_size} '
f'is not suitable, please reset it to 40.')
elif self._max_sample_nums >= 1000:
if self._pop_size is None:
self._pop_size = 20
elif abs(self._pop_size - 20) > 10:
print(f'Warning: population size {self._pop_size} '
f'is not suitable, please reset it to 20.')
elif self._max_sample_nums >= 200:
if self._pop_size is None:
self._pop_size = 10
elif abs(self._pop_size - 10) > 5:
print(f'Warning: population size {self._pop_size} '
f'is not suitable, please reset it to 10.')
else:
if self._pop_size is None:
self._pop_size = 5
elif abs(self._pop_size - 5) > 5:
print(f'Warning: population size {self._pop_size} '
f'is not suitable, please reset it to 5.')

def _sample_evaluate_register(self, prompt):
"""Sample a function using the given prompt -> evaluate it by submitting to the process/thread pool ->
add the function to the population and register it to the profiler.
"""Perform following steps:
1. Sample an algorithm using the given prompt.
2. Evaluate it by submitting to the process/thread pool, and get the results.
3. Add the function to the population and register it to the profiler.
"""
sample_start = time.time()
thought, func = self._sampler.get_thought_and_function(prompt)
sample_time = time.time() - sample_start
if thought is None or func is None:
return

# convert to Program instance
program = TextFunctionProgramConverter.function_to_program(func, self._template_program)
if program is None:
return

# evaluate
score, eval_time = self._evaluation_executor.submit(
self._evaluator.evaluate_program_record_time,
program
).result()

# score
# register to profiler
func.score = score
func.evaluate_time = eval_time
func.algorithm = thought
func.sample_time = sample_time
try:
if self._profiler is not None:
self._profiler.register_function(func, program=str(program))
if isinstance(self._profiler, MEoHProfiler):
self._profiler.register_population(self._population)
self._tot_sample_nums += 1
except Exception as e:
traceback.print_exc()
if self._profiler is not None:
self._profiler.register_function(func)
if isinstance(self._profiler, MEoHProfiler):
self._profiler.register_population(self._population)
self._tot_sample_nums += 1

# register to the population
self._population.register_function(func)

def _continue_sample(self):
"""Check if it meets the max_sample_nums restrictions.
"""
def _continue_loop(self) -> bool:
if self._max_generations is None and self._max_sample_nums is None:
return True
if self._max_generations is None and self._max_sample_nums is not None:
if self._tot_sample_nums < self._max_sample_nums:
return True
else:
return False
if self._max_generations is not None and self._max_sample_nums is None:
if self._population.generation < self._max_generations:
return True
else:
return False
if self._max_generations is not None and self._max_sample_nums is not None:
continue_until_reach_gen = False
continue_until_reach_sample = False
if self._population.generation < self._max_generations:
continue_until_reach_gen = True
if self._tot_sample_nums < self._max_sample_nums:
continue_until_reach_sample = True
return continue_until_reach_gen and continue_until_reach_sample
elif self._max_generations is not None and self._max_sample_nums is None:
return self._population.generation < self._max_generations
elif self._max_generations is None and self._max_sample_nums is not None:
return self._tot_sample_nums < self._max_sample_nums
else:
return (self._population.generation < self._max_generations
and self._tot_sample_nums < self._max_sample_nums)

def _thread_do_evolutionary_operator(self):
while self._continue_sample():
while self._continue_loop():
try:
# get a new func using e1
indivs = [self._population.selection() for _ in range(self._selection_num)]
prompt = MEoHPrompt.get_prompt_e1(self._task_description_str, indivs, self._function_to_evolve)

if self._debug_mode:
print(prompt)
input()

print(f'E1 Prompt: {prompt}')
self._sample_evaluate_register(prompt)
if not self._continue_sample():
if not self._continue_loop():
break

# get a new func using e2
if self._use_e2_operator:
indivs = [self._population.selection() for _ in range(self._selection_num)]
prompt = MEoHPrompt.get_prompt_e2(self._task_description_str, indivs, self._function_to_evolve)

if self._debug_mode:
print(prompt)
input()

print(f'E2 Prompt: {prompt}')
self._sample_evaluate_register(prompt)
if not self._continue_sample():
if not self._continue_loop():
break

# get a new func using m1
if self._use_m1_operator:
indiv = self._population.selection()
prompt = MEoHPrompt.get_prompt_m1(self._task_description_str, indiv, self._function_to_evolve)

if self._debug_mode:
print(prompt)
input()

print(f'M1 Prompt: {prompt}')
self._sample_evaluate_register(prompt)
if not self._continue_sample():
if not self._continue_loop():
break

# get a new func using m2
if self._use_m2_operator:
indiv = self._population.selection()
prompt = MEoHPrompt.get_prompt_m2(self._task_description_str, indiv, self._function_to_evolve)

if self._debug_mode:
print(prompt)
input()

print(f'M2 Prompt: {prompt}')
self._sample_evaluate_register(prompt)
if not self._continue_sample():
if not self._continue_loop():
break
except KeyboardInterrupt:
break
Expand All @@ -258,40 +272,29 @@ def _thread_do_evolutionary_operator(self):
except:
pass

def _thread_init_population(self):
def _iteratively_init_population(self):
"""Let a thread repeat {sample -> evaluate -> register to population}
to initialize a population.
"""
while self._population.generation == 0:
if not self._continue_sample():
break
try:
# get a new func using i1
prompt = MEoHPrompt.get_prompt_i1(self._task_description_str, self._function_to_evolve)
self._sample_evaluate_register(prompt)
except Exception as e:
if self._tot_sample_nums > self._initial_sample_nums_max:
print(f'Warning: Initialization not accomplished in {self._initial_sample_nums_max} samples !!!')
break
except Exception:
if self._debug_mode:
traceback.print_exc()
exit()
continue

def _init_population(self):
def _multi_threaded_sampling(self, fn: callable, *args, **kwargs):
# threads for sampling
sampler_threads = [
Thread(
target=self._thread_init_population,
) for _ in range(self._num_samplers)
]
for t in sampler_threads:
t.start()
for t in sampler_threads:
t.join()

def _do_sample(self):
sampler_threads = [
Thread(
target=self._thread_do_evolutionary_operator,
) for _ in range(self._num_samplers)
Thread(target=fn, args=args, kwargs=kwargs)
for _ in range(self._num_samplers)
]
for t in sampler_threads:
t.start()
Expand All @@ -300,17 +303,19 @@ def _do_sample(self):

def run(self):
if not self._resume_mode:
# do init
self._population = Population(pop_size=self._pop_size)
self._init_population()
while len([f for f in self._population if not np.isinf(np.array(f.score)).any()]) < self._selection_num:
self._population._generation -= 1
self._init_population()
# do initialization
self._multi_threaded_sampling(self._iteratively_init_population)
# while len([f for f in self._population if not np.isinf(np.array(f.score)).any()]) < self._selection_num:
# self._population._generation -= 1
# self._init_population()
if len(self._population) < self._selection_num:
print(
f'The search is terminated since MEoH unable to obtain {self._selection_num} feasible algorithms during initialization. '
f'Please increase the `initial_sample_nums_max` argument (currently {self._initial_sample_nums_max}). '
f'Please also check your evaluation implementation and LLM implementation.')
return
# do evolve
self._do_sample()

self._multi_threaded_sampling(self._thread_do_evolutionary_operator)
# finish
if self._profiler is not None:
self._profiler.finish()

self._sampler.llm.close()
Loading