diff --git a/.env.example b/.env.example new file mode 100644 index 00000000..e399fad0 --- /dev/null +++ b/.env.example @@ -0,0 +1,7 @@ +HF_TOKEN="" +OPENAI_API_KEY="sk-proj-" +GEMINI_API_KEY="" +S2_API_KEY="" +MAX_NUM_TOKENS=4096 +LLM_MODEL="gemini-2.0-flash-exp" +LLM_SMALL_MODEL="gemini-2.0-flash-exp" \ No newline at end of file diff --git a/README.md b/README.md index d475f333..e57ededa 100644 --- a/README.md +++ b/README.md @@ -58,11 +58,14 @@ conda install conda-forge::chktex pip install -r requirements.txt ``` +You may also need to install pdflatex; see https://gist.github.com/rain1024/98dd5e2c6c8c28f9ea9d for installation instructions. + ### Supported Models and API Keys #### OpenAI Models -By default, the system uses the `OPENAI_API_KEY` environment variable for OpenAI models. +By default, the system uses the `OPENAI_API_KEY` environment variable. The other supported environment variables are listed in `.env.example`. + #### Claude Models via AWS Bedrock @@ -148,6 +151,20 @@ python launch_scientist_bfts.py \ Once the initial experimental stage is complete, you will find a timestamped log folder inside the `experiments/` directory. Navigate to `experiments/"timestamp_ideaname"/logs/0-run/` within that folder to find the tree visualization file `unified_tree_viz.html`. +### Running with Gemini +Set the `GEMINI_API_KEY` environment variable to your Gemini API key, and set `LLM_MODEL` to a Gemini model such as `gemini-2.0-flash-exp`. You can then run the experiment with the Gemini model using the following command: + +```bash +python launch_scientist_bfts.py \ + --load_ideas "ai_scientist/ideas/automated_concept_sae_eval.json" \ + --add_dataset_ref \ + --model_writeup "gemini-2.0-flash-exp" \ + --model_citation "gemini-2.0-flash-exp" \ + --model_review "gemini-2.0-flash-exp" \ + --model_agg_plots "gemini-2.0-flash-exp" \ + --num_cite_rounds 20 \ + --config_path "bfts_config_gemini.yaml" +``` ## Citing The AI Scientist-v2 If you use **The AI Scientist-v2** in your research, please cite our work as follows: diff --git a/ai_scientist/llm.py b/ai_scientist/llm.py index e265f096..86058767 100644 --- a/ai_scientist/llm.py +++ b/ai_scientist/llm.py @@ -7,8 +7,10 @@ import anthropic import backoff import openai +from dotenv import load_dotenv -MAX_NUM_TOKENS = 4096 +load_dotenv() +MAX_NUM_TOKENS = int(os.environ.get("MAX_NUM_TOKENS", 4096)) AVAILABLE_LLMS = [ "claude-3-5-sonnet-20240620", @@ -147,7 +149,7 @@ def get_batch_responses_from_llm( print() print("*" * 20 + " LLM START " + "*" * 20) for j, msg in enumerate(new_msg_history[0]): - print(f'{j}, {msg["role"]}: {msg["content"]}') + print(f"{j}, {msg['role']}: {msg['content']}") print(content) print("*" * 21 + " LLM END " + "*" * 21) print() @@ -170,6 +172,16 @@ def make_llm_call(client, model, temperature, system_message, prompt): stop=None, seed=0, ) + elif "gemini" in model: + return client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_message}, + *prompt, + ], + temperature=temperature, + max_tokens=MAX_NUM_TOKENS, + ) elif "o1" in model or "o3" in model: return client.chat.completions.create( model=model, @@ -239,18 +251,7 @@ def get_response_from_llm( ], } ] - elif "gpt" in model: - new_msg_history = msg_history + [{"role": "user", "content": msg}] - response = make_llm_call( - client, - model, - temperature, - system_message=system_message, - prompt=new_msg_history, - ) - content = response.choices[0].message.content - new_msg_history = new_msg_history + [{"role": "assistant", "content": content}] - elif "o1" in model or "o3" in model: + 
elif any(x in model for x in ["gpt", "gemini", "o1", "o3"]): new_msg_history = msg_history + [{"role": "user", "content": msg}] response = make_llm_call( client, @@ -342,7 +343,7 @@ def get_response_from_llm( print() print("*" * 20 + " LLM START " + "*" * 20) for j, msg in enumerate(new_msg_history): - print(f'{j}, {msg["role"]}: {msg["content"]}') + print(f"{j}, {msg['role']}: {msg['content']}") print(content) print("*" * 21 + " LLM END " + "*" * 21) print() @@ -396,6 +397,13 @@ def create_client(model) -> tuple[Any, str]: elif "o1" in model or "o3" in model: print(f"Using OpenAI API with model {model}.") return openai.OpenAI(), model + elif "gemini" in model: + print(f"Using Gemini API with model {model}.") + return openai.OpenAI( + max_retries=0, + api_key=os.getenv("GEMINI_API_KEY"), + base_url="https://generativelanguage.googleapis.com/v1beta/openai/", + ), model elif model == "deepseek-coder-v2-0724": print(f"Using OpenAI API with {model}.") return ( diff --git a/ai_scientist/perform_icbinb_writeup.py b/ai_scientist/perform_icbinb_writeup.py index 4080b97b..91a9e715 100644 --- a/ai_scientist/perform_icbinb_writeup.py +++ b/ai_scientist/perform_icbinb_writeup.py @@ -7,9 +7,9 @@ import subprocess import traceback import unicodedata -import uuid import tempfile - +from dotenv import load_dotenv +load_dotenv() from ai_scientist.llm import ( get_response_from_llm, extract_json_between_markers, @@ -17,7 +17,6 @@ AVAILABLE_LLMS, ) -from ai_scientist.utils.token_tracker import track_token_usage from ai_scientist.tools.semantic_scholar import search_for_papers @@ -733,6 +732,8 @@ def filter_experiment_summaries(exp_summaries, step_name): elif stage_name == "ABLATION_SUMMARY" and step_name == "plot_aggregation": filtered_summaries[stage_name] = {} for ablation_summary in exp_summaries[stage_name]: + if ablation_summary is None: + continue filtered_summaries[stage_name][ablation_summary["ablation_name"]] = {} for node_key in ablation_summary.keys(): if node_key in node_keys_to_keep: @@ -742,7 +743,7 @@ def filter_experiment_summaries(exp_summaries, step_name): return filtered_summaries -def gather_citations(base_folder, num_cite_rounds=20, small_model="gpt-4o-2024-05-13"): +def gather_citations(base_folder, num_cite_rounds=20, small_model=os.getenv("LLM_SMALL_MODEL", "gpt-4o-2024-05-13")): """ Gather citations for a paper, with ability to resume from previous progress. 
@@ -859,8 +860,8 @@ def perform_writeup( citations_text=None, no_writing=False, num_cite_rounds=20, - small_model="gpt-4o-2024-05-13", - big_model="o1-2024-12-17", + small_model=os.getenv("LLM_SMALL_MODEL", "gpt-4o-2024-05-13"), + big_model=os.getenv("LLM_MODEL", "o1-2024-12-17"), n_writeup_reflections=3, page_limit=4, ): @@ -950,7 +951,7 @@ def perform_writeup( # Generate VLM-based descriptions try: - vlm_client, vlm_model = create_vlm_client("gpt-4o-2024-05-13") + vlm_client, vlm_model = create_vlm_client(os.getenv("LLM_MODEL", "gpt-4o-2024-05-13")) desc_map = {} for pf in plot_names: ppath = osp.join(figures_dir, pf) @@ -1207,7 +1208,7 @@ def perform_writeup( base_folder, f"{osp.basename(base_folder)}_reflection_final_page_limit.pdf" ) # Compile current version before reflection - print(f"[green]Compiling PDF for reflection final page limit...[/green]") + print("[green]Compiling PDF for reflection final page limit...[/green]") print(f"reflection step {i+1}") @@ -1232,7 +1233,7 @@ def perform_writeup( compile_latex(latex_folder, reflection_pdf) else: - print(f"No changes in reflection page step.") + print("No changes in reflection page step.") return osp.exists(reflection_pdf) diff --git a/ai_scientist/perform_writeup.py b/ai_scientist/perform_writeup.py index 663539a9..41c2473e 100644 --- a/ai_scientist/perform_writeup.py +++ b/ai_scientist/perform_writeup.py @@ -15,9 +15,7 @@ create_client, AVAILABLE_LLMS, ) - from ai_scientist.tools.semantic_scholar import search_for_papers - from ai_scientist.perform_vlm_review import generate_vlm_img_review from ai_scientist.vlm import create_client as create_vlm_client @@ -456,8 +454,8 @@ def perform_writeup( base_folder, no_writing=False, num_cite_rounds=20, - small_model="gpt-4o-2024-05-13", - big_model="o1-2024-12-17", + small_model=os.getenv("LLM_SMALL_MODEL", "gpt-4o-2024-05-13"), + big_model=os.getenv("LLM_MODEL", "o1-2024-12-17"), n_writeup_reflections=3, page_limit=8, ): @@ -525,6 +523,11 @@ def perform_writeup( for fplot in os.listdir(figures_dir): if fplot.lower().endswith(".png"): plot_names.append(fplot) + # copy plots to ./latex folder + for fplot in plot_names: + src_path = osp.join(figures_dir, fplot) + dest_path = osp.join(latex_folder, fplot) + shutil.copy(src_path, dest_path) # Load aggregator script to include in the prompt aggregator_path = osp.join(base_folder, "auto_plot_aggregator.py") @@ -589,7 +592,7 @@ def perform_writeup( # Generate VLM-based descriptions but do not overwrite plot_names try: - vlm_client, vlm_model = create_vlm_client("gpt-4o-2024-05-13") + vlm_client, vlm_model = create_vlm_client(os.getenv("LLM_SMALL_MODEL", "gpt-4o-2024-05-13")) desc_map = {} for pf in plot_names: ppath = osp.join(figures_dir, pf) diff --git a/ai_scientist/treesearch/agent_manager.py b/ai_scientist/treesearch/agent_manager.py index a253b56d..b535df09 100644 --- a/ai_scientist/treesearch/agent_manager.py +++ b/ai_scientist/treesearch/agent_manager.py @@ -880,6 +880,7 @@ def _create_stage_analysis_prompt( # Save stage transition analysis to notes directory base_dir = Path(self.workspace_dir).parent.parent run_name = Path(self.workspace_dir).name + stage_number = previous_stages[-1].stage_number notes_dir = ( base_dir / "logs" diff --git a/ai_scientist/treesearch/backend/backend_openai.py b/ai_scientist/treesearch/backend/backend_openai.py index ae318ec4..440b6353 100644 --- a/ai_scientist/treesearch/backend/backend_openai.py +++ b/ai_scientist/treesearch/backend/backend_openai.py @@ -6,7 +6,10 @@ from funcy import notnone, once, 
select_values import openai from rich import print +import os +from dotenv import load_dotenv +load_dotenv() logger = logging.getLogger("ai-scientist") _client: openai.OpenAI = None # type: ignore @@ -22,7 +25,17 @@ @once def _setup_openai_client(): global _client - _client = openai.OpenAI(max_retries=0) + gemini_api_key = os.getenv("GEMINI_API_KEY") + if gemini_api_key is None: + _client = openai.OpenAI( + max_retries=0, + ) + else: + _client = openai.OpenAI( + max_retries=0, + api_key=os.getenv("GEMINI_API_KEY"), + base_url="https://generativelanguage.googleapis.com/v1beta/openai/", + ) def query( @@ -55,12 +68,12 @@ def query( if func_spec is None: output = choice.message.content else: - assert ( - choice.message.tool_calls - ), f"function_call is empty, it is not a function call: {choice.message}" - assert ( - choice.message.tool_calls[0].function.name == func_spec.name - ), "Function name mismatch" + assert choice.message.tool_calls, ( + f"function_call is empty, it is not a function call: {choice.message}" + ) + assert choice.message.tool_calls[0].function.name == func_spec.name, ( + "Function name mismatch" + ) try: print(f"[cyan]Raw func call response: {choice}[/cyan]") output = json.loads(choice.message.tool_calls[0].function.arguments) diff --git a/ai_scientist/treesearch/backend/utils.py b/ai_scientist/treesearch/backend/utils.py index 60419c7c..89fef394 100644 --- a/ai_scientist/treesearch/backend/utils.py +++ b/ai_scientist/treesearch/backend/utils.py @@ -34,10 +34,7 @@ def opt_messages_to_list( system_message: str | None, user_message: str | None ) -> list[dict[str, str]]: messages = [] - if system_message: - messages.append({"role": "system", "content": system_message}) - if user_message: - messages.append({"role": "user", "content": user_message}) + messages.append({"role": "user", "content": (system_message or "") + "\n" + (user_message or "")}) return messages diff --git a/ai_scientist/treesearch/journal.py b/ai_scientist/treesearch/journal.py index 0dc1695d..eee88245 100644 --- a/ai_scientist/treesearch/journal.py +++ b/ai_scientist/treesearch/journal.py @@ -2,7 +2,7 @@ import time import uuid from dataclasses import dataclass, field -from typing import Literal, Optional +from typing import Literal, Optional, List, Dict import copy import os import json @@ -17,7 +17,9 @@ import logging from pathlib import Path +from dotenv import load_dotenv +load_dotenv() logger = logging.getLogger(__name__) node_selection_spec = FunctionSpec( @@ -347,13 +349,13 @@ def generate_nb_trace(self, include_prompt, comment_headers=True) -> str: trace = [] header_prefix = "## " if comment_headers else "" for n in self.nodes: - trace.append(f"\n{header_prefix}In [{n.step+1}]:\n") + trace.append(f"\n{header_prefix}In [{n.step + 1}]:\n") trace.append(n.code) - trace.append(f"\n{header_prefix}Out [{n.step+1}]:\n") + trace.append(f"\n{header_prefix}Out [{n.step + 1}]:\n") trace.append(n.term_out) if include_prompt and self.nodes: - trace.append(f"\n{header_prefix}In [{self.nodes[-1].step+2}]:\n") + trace.append(f"\n{header_prefix}In [{self.nodes[-1].step + 2}]:\n") return "\n".join(trace).strip() @@ -363,6 +365,7 @@ class Journal: """A collection of nodes representing the solution tree.""" nodes: list[Node] = field(default_factory=list) + model: str = field(default=os.getenv("LLM_MODEL", "gpt-4o")) def __getitem__(self, idx: int) -> Node: return self.nodes[idx] @@ -452,13 +455,13 @@ def get_best_node(self, only_good=True, use_val_metric_only=False) -> None | Nod for node in nodes: if not 
node.is_seed_node: candidate_info = ( - f"ID: {node.id}\n" f"Metric: {str(node.metric)}\n" + f"ID: {node.id}\nMetric: {str(node.metric)}\n" if node.metric else ( - "N/A\n" f"Training Analysis: {node.analysis}\n" + f"N/A\nTraining Analysis: {node.analysis}\n" if hasattr(node, "analysis") else ( - "N/A\n" f"VLM Feedback: {node.vlm_feedback_summary}\n" + f"N/A\nVLM Feedback: {node.vlm_feedback_summary}\n" if hasattr(node, "vlm_feedback_summary") else "N/A\n" ) @@ -471,7 +474,7 @@ def get_best_node(self, only_good=True, use_val_metric_only=False) -> None | Nod system_message=prompt, user_message=None, func_spec=node_selection_spec, - model="gpt-4o", + model=self.model, temperature=0.3, ) @@ -535,7 +538,7 @@ def generate_summary(self, include_code: bool = False) -> str: "2. Common failure patterns and pitfalls to avoid\n" "3. Specific recommendations for future experiments based on both successes and failures" ), - model="gpt-4o", + model=self.model, temperature=0.3, ) @@ -599,7 +602,7 @@ def save_experiment_notes(self, workspace_dir: str, stage_name: str): stage_summary = query( system_message=summary_prompt, user_message="Generate a comprehensive summary of the experimental findings in this stage", - model="gpt-4", + model=self.model, temperature=0.3, ) diff --git a/ai_scientist/treesearch/log_summarization.py b/ai_scientist/treesearch/log_summarization.py index 436cdd18..ce49f9c7 100644 --- a/ai_scientist/treesearch/log_summarization.py +++ b/ai_scientist/treesearch/log_summarization.py @@ -1,17 +1,19 @@ import json import os import sys - -import openai - +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm from .journal import Node, Journal +from dotenv import load_dotenv +load_dotenv() parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) sys.path.insert(0, parent_dir) from ai_scientist.llm import get_response_from_llm, extract_json_between_markers +from ai_scientist.llm import create_client -client = openai.OpenAI() -model = "gpt-4o-2024-08-06" +model = os.getenv("LLM_MODEL", "gpt-4o-2024-08-06") +client, model = create_client(model) report_summarizer_sys_msg = """You are an expert machine learning researcher. You are given multiple experiment logs, each representing a node in a stage of exploring scientific ideas and implementations. 
@@ -295,7 +297,6 @@ def annotate_history(journal): def overall_summarize(journals): - from concurrent.futures import ThreadPoolExecutor def process_stage(idx, stage_tuple): stage_name, journal = stage_tuple @@ -339,8 +340,6 @@ def process_stage(idx, stage_tuple): summary_json = get_stage_summary(journal, stage_name, model, client) return summary_json - from tqdm import tqdm - with ThreadPoolExecutor() as executor: results = list( tqdm( @@ -349,9 +348,14 @@ def process_stage(idx, stage_tuple): total=len(list(journals)), ) ) - draft_summary, baseline_summary, research_summary, ablation_summary = results - - return draft_summary, baseline_summary, research_summary, ablation_summary + print(f"WARNING: {len(results)} stages found") + # Safely unpack results, which may have 1 to 4 items + draft_summary = results[0] if len(results) > 0 else dict() + baseline_summary = results[1] if len(results) > 1 else dict() + research_summary = results[2] if len(results) > 2 else dict() + ablation_summary = results[3] if len(results) > 3 else dict() + + return draft_summary, baseline_summary, research_summary, ablation_summary if __name__ == "__main__": diff --git a/ai_scientist/treesearch/perform_experiments_bfts_with_agentmanager.py b/ai_scientist/treesearch/perform_experiments_bfts_with_agentmanager.py index 62564f15..36be902f 100644 --- a/ai_scientist/treesearch/perform_experiments_bfts_with_agentmanager.py +++ b/ai_scientist/treesearch/perform_experiments_bfts_with_agentmanager.py @@ -94,6 +94,8 @@ def cleanup(): def create_exec_callback(status_obj): def exec_callback(*args, **kwargs): status_obj.update("[magenta]Executing code...") + from .interpreter import Interpreter # Import the Interpreter class or module + interpreter = Interpreter(cfg.workspace_dir) # Initialize the interpreter res = interpreter.run(*args, **kwargs) status_obj.update("[green]Generating code...") return res @@ -235,7 +237,7 @@ def generate_live(manager): with open(ablation_summary_path, "w") as ablation_file: json.dump(ablation_summary, ablation_file, indent=2) - print(f"Summary reports written to files:") + print("Summary reports written to files:") print(f"- Draft summary: {draft_summary_path}") print(f"- Baseline summary: {baseline_summary_path}") print(f"- Research summary: {research_summary_path}") diff --git a/ai_scientist/utils/token_tracker.py b/ai_scientist/utils/token_tracker.py index 0b0806b7..03f2beb9 100644 --- a/ai_scientist/utils/token_tracker.py +++ b/ai_scientist/utils/token_tracker.py @@ -1,6 +1,5 @@ from functools import wraps from typing import Dict, Optional, List -import tiktoken from collections import defaultdict import asyncio from datetime import datetime @@ -196,16 +195,22 @@ def sync_wrapper(*args, **kwargs): logging.info("kwargs: ", kwargs) if hasattr(result, "usage"): + reasoning_tokens = ( + result.usage.completion_tokens_details.reasoning_tokens + if result.usage.completion_tokens_details + else 0 + ) + cached_tokens = ( + result.usage.prompt_tokens_details.cached_tokens + if hasattr(result.usage, "prompt_tokens_details") and result.usage.prompt_tokens_details is not None + else 0 + ) token_tracker.add_tokens( model, result.usage.prompt_tokens, result.usage.completion_tokens, - result.usage.completion_tokens_details.reasoning_tokens, - ( - result.usage.prompt_tokens_details.cached_tokens - if hasattr(result.usage, "prompt_tokens_details") - else 0 - ), + reasoning_tokens, + cached_tokens, ) # Add interaction details token_tracker.add_interaction( diff --git a/ai_scientist/vlm.py b/ai_scientist/vlm.py 
index 240015eb..948d5c58 100644 --- a/ai_scientist/vlm.py +++ b/ai_scientist/vlm.py @@ -6,8 +6,11 @@ import openai from PIL import Image from ai_scientist.utils.token_tracker import track_token_usage +import os +from dotenv import load_dotenv -MAX_NUM_TOKENS = 4096 +load_dotenv() +MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 4096)) AVAILABLE_VLMS = [ "gpt-4o-2024-05-13", @@ -67,7 +70,7 @@ def make_llm_call(client, model, temperature, system_message, prompt): @track_token_usage def make_vlm_call(client, model, temperature, system_message, prompt): - if "gpt" in model: + if "gpt" in model or "gemini" in model: return client.chat.completions.create( model=model, messages=[ @@ -107,7 +110,7 @@ def get_response_from_vlm( if msg_history is None: msg_history = [] - if model in AVAILABLE_VLMS: + if model in AVAILABLE_VLMS or "gemini" in model: # Convert single image path to list for consistent handling if isinstance(image_paths, str): image_paths = [image_paths] @@ -147,7 +150,7 @@ def get_response_from_vlm( print() print("*" * 20 + " VLM START " + "*" * 20) for j, msg in enumerate(new_msg_history): - print(f'{j}, {msg["role"]}: {msg["content"]}') + print(f"{j}, {msg['role']}: {msg['content']}") print(content) print("*" * 21 + " VLM END " + "*" * 21) print() @@ -166,6 +169,13 @@ def create_client(model: str) -> tuple[Any, str]: ]: print(f"Using OpenAI API with model {model}.") return openai.OpenAI(), model + elif "gemini" in model: + print(f"Using Gemini API with model {model}.") + return openai.OpenAI( + max_retries=0, + api_key=os.getenv("GEMINI_API_KEY"), + base_url="https://generativelanguage.googleapis.com/v1beta/openai/", + ), model else: raise ValueError(f"Model {model} not supported.") @@ -242,7 +252,7 @@ def get_batch_responses_from_vlm( "gpt-4o-2024-11-20", "gpt-4o-mini-2024-07-18", "o3-mini", - ]: + ] or "gemini" in model: # Convert single image path to list if isinstance(image_paths, str): image_paths = [image_paths] @@ -290,7 +300,7 @@ def get_batch_responses_from_vlm( print() print("*" * 20 + " VLM START " + "*" * 20) for j, msg in enumerate(new_msg_histories[0]): - print(f'{j}, {msg["role"]}: {msg["content"]}') + print(f"{j}, {msg['role']}: {msg['content']}") print(contents[0]) print("*" * 21 + " VLM END " + "*" * 21) print() diff --git a/bfts_config_gemini.yaml b/bfts_config_gemini.yaml new file mode 100644 index 00000000..2c2d4413 --- /dev/null +++ b/bfts_config_gemini.yaml @@ -0,0 +1,79 @@ +# path to the task data directory +data_dir: "data" +preprocess_data: False + +goal: null +eval: null + +log_dir: logs +workspace_dir: workspaces + +# whether to copy the data to the workspace directory (otherwise it will be symlinked) +# copying is recommended to prevent the agent from accidentally modifying the original data +copy_data: True + +exp_name: run # a random experiment name will be generated if not provided + +# settings for code execution +exec: + timeout: 3600 + agent_file_name: runfile.py + format_tb_ipython: False + +generate_report: True +# LLM settings for final report from journal +report: + # model: gemini-2.5-pro-exp-03-25 + model: gemini-2.0-flash-exp + temp: 1.0 + +experiment: + num_syn_datasets: 1 + +debug: + stage4: False + +# agent hyperparams +agent: + type: parallel + num_workers: 4 + stages: + stage1_max_iters: 10 + stage2_max_iters: 2 + stage3_max_iters: 2 + stage4_max_iters: 2 + # how many improvement iterations to run + steps: 2 # if stage-specific max_iters are not provided, the agent will use this value for all stages + # whether to instruct the agent to 
use CV (set to 1 to disable) + k_fold_validation: 1 + multi_seed_eval: + num_seeds: 3 # should be the same as num_workers if num_workers < 3. Otherwise, set it to be 3. + # whether to instruct the agent to generate a prediction function + expose_prediction: False + # whether to provide the agent with a preview of the data + data_preview: False + + # LLM settings for coding + code: + # model: gemini-2.5-pro-exp-03-25 + model: gemini-2.0-flash-exp + temp: 1.0 + max_tokens: 8192 + + # LLM settings for evaluating program output / tracebacks + feedback: + # model: gemini-2.5-pro-exp-03-25 + model: gemini-2.0-flash-exp + temp: 0.5 + max_tokens: 8192 + + vlm_feedback: + # model: gemini-2.5-pro-exp-03-25 + model: gemini-2.0-flash-exp + temp: 0.5 + max_tokens: null + + search: + max_debug_depth: 3 + debug_prob: 0.5 + num_drafts: 3 diff --git a/launch_scientist_bfts.py b/launch_scientist_bfts.py index ea8f02a7..b1063acd 100644 --- a/launch_scientist_bfts.py +++ b/launch_scientist_bfts.py @@ -26,6 +26,9 @@ from ai_scientist.perform_llm_review import perform_review, load_paper from ai_scientist.perform_vlm_review import perform_imgs_cap_ref_review from ai_scientist.utils.token_tracker import token_tracker +from dotenv import load_dotenv + +load_dotenv() def print_time(): @@ -122,6 +125,12 @@ def parse_arguments(): action="store_true", help="If set, skip the review process", ) + parser.add_argument( + "--config_path", + type=str, + default="bfts_config.yaml", + help="Path to the bfts_config.yaml file", + ) return parser.parse_args() @@ -134,6 +143,7 @@ def get_available_gpus(gpu_ids=None): def find_pdf_path_for_review(idea_dir): pdf_files = [f for f in os.listdir(idea_dir) if f.endswith(".pdf")] reflection_pdfs = [f for f in pdf_files if "reflection" in f] + pdf_path = None if reflection_pdfs: # First check if there's a final version final_pdfs = [f for f in reflection_pdfs if "final" in f.lower()] @@ -186,6 +196,9 @@ def redirect_stdout_stderr_to_file(log_file_path): ideas = json.load(f) print(f"Loaded {len(ideas)} pregenerated ideas from {args.load_ideas}") + if isinstance(ideas, dict): + # If the JSON file contains a single idea, convert it to a list + ideas = [ideas] idea = ideas[args.idea_idx] date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") @@ -240,7 +253,7 @@ def redirect_stdout_stderr_to_file(log_file_path): with open(idea_path_json, "w") as f: json.dump(ideas[args.idea_idx], f, indent=4) - config_path = "bfts_config.yaml" + config_path = args.config_path idea_config_path = edit_bfts_config_file( config_path, idea_dir, @@ -270,7 +283,7 @@ def redirect_stdout_stderr_to_file(log_file_path): small_model=args.model_citation, ) for attempt in range(args.writeup_retries): - print(f"Writeup attempt {attempt+1} of {args.writeup_retries}") + print(f"Writeup attempt {attempt + 1} of {args.writeup_retries}") if args.writeup_type == "normal": writeup_success = perform_writeup( base_folder=idea_dir, @@ -296,7 +309,7 @@ def redirect_stdout_stderr_to_file(log_file_path): if not args.skip_review and not args.skip_writeup: # Perform paper review if the paper exists pdf_path = find_pdf_path_for_review(idea_dir) - if os.path.exists(pdf_path): + if pdf_path and os.path.exists(pdf_path): print("Paper found at: ", pdf_path) paper_content = load_paper(pdf_path) client, client_model = create_client(args.model_review)
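
Note: the snippet below is a minimal sketch, not part of the patch, illustrating how the Gemini support added in `ai_scientist/llm.py` and `ai_scientist/vlm.py` talks to Gemini through Google's OpenAI-compatible endpoint. It assumes `GEMINI_API_KEY` is set (for example via `.env`) and that the `openai` and `python-dotenv` packages are installed; the prompt strings are placeholders.

```python
# Minimal sketch of the Gemini path introduced by this patch: the OpenAI SDK
# is pointed at Google's OpenAI-compatible endpoint, mirroring create_client()
# in ai_scientist/llm.py. Prompt contents below are placeholders.
import os

import openai
from dotenv import load_dotenv

load_dotenv()  # reads GEMINI_API_KEY, LLM_MODEL, MAX_NUM_TOKENS from .env

client = openai.OpenAI(
    max_retries=0,
    api_key=os.getenv("GEMINI_API_KEY"),
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)

response = client.chat.completions.create(
    model=os.getenv("LLM_MODEL", "gemini-2.0-flash-exp"),
    messages=[
        {"role": "system", "content": "You are a helpful research assistant."},
        {"role": "user", "content": "Say hello in one short sentence."},
    ],
    temperature=0.7,
    max_tokens=int(os.environ.get("MAX_NUM_TOKENS", 4096)),
)
print(response.choices[0].message.content)
```

Because this mirrors the client construction used by `create_client` and `make_llm_call` in the patch, running it standalone is a quick way to verify the API key and endpoint before launching `launch_scientist_bfts.py` with `bfts_config_gemini.yaml`.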