7 changes: 7 additions & 0 deletions .env.example
@@ -0,0 +1,7 @@
HF_TOKEN=""
OPENAI_API_KEY="sk-proj-"
GEMINI_API_KEY=""
S2_API_KEY=""
MAX_NUM_TOKENS=4096
LLM_MODEL="gemini-2.0-flash-exp"
LLM_SMALL_MODEL="gemini-2.0-flash-exp"
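
For reference, here is a minimal sketch (assuming `python-dotenv` is installed) of how these variables are picked up at runtime; it mirrors the `load_dotenv()` call this PR adds to `ai_scientist/llm.py`:

```python
# Sketch: load .env values the way ai_scientist/llm.py now does.
# Assumes python-dotenv is installed; names match .env.example above.
import os

from dotenv import load_dotenv

load_dotenv()  # reads variables from a local .env file into the process env

MAX_NUM_TOKENS = int(os.environ.get("MAX_NUM_TOKENS", 4096))
LLM_MODEL = os.getenv("LLM_MODEL", "gemini-2.0-flash-exp")
print(MAX_NUM_TOKENS, LLM_MODEL)
```
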
19 changes: 18 additions & 1 deletion README.md
@@ -58,11 +58,14 @@ conda install conda-forge::chktex
pip install -r requirements.txt
```

You may also need to install `pdflatex`; see https://gist.github.com/rain1024/98dd5e2c6c8c28f9ea9d for instructions.

### Supported Models and API Keys

#### OpenAI Models

By default, the system uses the `OPENAI_API_KEY` environment variable for OpenAI models.
By default, the system uses the `OPENAI_API_KEY` environment variable. See `.env.example` for the full list of supported environment variables.


#### Claude Models via AWS Bedrock

@@ -148,6 +151,20 @@ python launch_scientist_bfts.py \

Once the initial experimental stage is complete, you will find a timestamped log folder inside the `experiments/` directory. Navigate to `experiments/"timestamp_ideaname"/logs/0-run/` within that folder to find the tree visualization file `unified_tree_viz.html`.

### Running with Gemini
Set the `GEMINI_API_KEY` environment variable to your Gemini API key, and set `LLM_MODEL` to a Gemini model such as `gemini-2.0-flash-exp`. You can then run the experiment with the Gemini model using the following command:

```bash
python launch_scientist_bfts.py \
--load_ideas "ai_scientist/ideas/automated_concept_sae_eval.json" \
--add_dataset_ref \
--model_writeup "gemini-2.0-flash-exp" \
--model_citation "gemini-2.0-flash-exp" \
--model_review "gemini-2.0-flash-exp" \
--model_agg_plots "gemini-2.0-flash-exp" \
--num_cite_rounds 20 \
--config_path "bfts_config_gemini.yaml"
```
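
Before launching a full run, it can help to smoke-test the key. This sketch (not part of the repo) calls Google's OpenAI-compatible endpoint directly with the `openai` client, using the same base URL this PR adds in `ai_scientist/llm.py`:

```python
# Hedged smoke test: verify GEMINI_API_KEY against the OpenAI-compatible
# endpoint used by create_client() in ai_scientist/llm.py.
import os

import openai

client = openai.OpenAI(
    api_key=os.getenv("GEMINI_API_KEY"),
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)
resp = client.chat.completions.create(
    model="gemini-2.0-flash-exp",
    messages=[{"role": "user", "content": "Say hello."}],
)
print(resp.choices[0].message.content)
```

If this prints a greeting, the key and endpoint are working.
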
## Citing The AI Scientist-v2

If you use **The AI Scientist-v2** in your research, please cite our work as follows:
38 changes: 23 additions & 15 deletions ai_scientist/llm.py
@@ -7,8 +7,10 @@
import anthropic
import backoff
import openai
from dotenv import load_dotenv

MAX_NUM_TOKENS = 4096
load_dotenv()
MAX_NUM_TOKENS = int(os.environ.get("MAX_NUM_TOKENS", 4096))

AVAILABLE_LLMS = [
"claude-3-5-sonnet-20240620",
@@ -147,7 +149,7 @@ def get_batch_responses_from_llm(
print()
print("*" * 20 + " LLM START " + "*" * 20)
for j, msg in enumerate(new_msg_history[0]):
print(f'{j}, {msg["role"]}: {msg["content"]}')
print(f"{j}, {msg['role']}: {msg['content']}")
print(content)
print("*" * 21 + " LLM END " + "*" * 21)
print()
@@ -170,6 +172,16 @@ def make_llm_call(client, model, temperature, system_message, prompt):
stop=None,
seed=0,
)
elif "gemini" in model:
return client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_message},
*prompt,
],
temperature=temperature,
max_tokens=MAX_NUM_TOKENS,
)
elif "o1" in model or "o3" in model:
return client.chat.completions.create(
model=model,
@@ -239,18 +251,7 @@ def get_response_from_llm(
],
}
]
elif "gpt" in model:
new_msg_history = msg_history + [{"role": "user", "content": msg}]
response = make_llm_call(
client,
model,
temperature,
system_message=system_message,
prompt=new_msg_history,
)
content = response.choices[0].message.content
new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
elif "o1" in model or "o3" in model:
elif any(x in model for x in ["gpt", "gemini", "o1", "o3"]):
new_msg_history = msg_history + [{"role": "user", "content": msg}]
response = make_llm_call(
client,
@@ -342,7 +343,7 @@ def get_response_from_llm(
print()
print("*" * 20 + " LLM START " + "*" * 20)
for j, msg in enumerate(new_msg_history):
print(f'{j}, {msg["role"]}: {msg["content"]}')
print(f"{j}, {msg['role']}: {msg['content']}")
print(content)
print("*" * 21 + " LLM END " + "*" * 21)
print()
@@ -396,6 +397,13 @@ def create_client(model) -> tuple[Any, str]:
elif "o1" in model or "o3" in model:
print(f"Using OpenAI API with model {model}.")
return openai.OpenAI(), model
elif "gemini" in model:
print(f"Using Gemini API with model {model}.")
return openai.OpenAI(
max_retries=0,
api_key=os.getenv("GEMINI_API_KEY"),
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
), model
elif model == "deepseek-coder-v2-0724":
print(f"Using OpenAI API with {model}.")
return (
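
With these changes, `create_client` returns an OpenAI-compatible client pointed at Google's endpoint whenever the model name contains `gemini`, and `get_response_from_llm` routes such models through the shared `make_llm_call` path. A usage sketch, assuming the parameter names visible in this diff (`msg`, `client`, `model`, `system_message`):

```python
# Sketch of the new routing; argument names are taken from the diff above.
from ai_scientist.llm import create_client, get_response_from_llm

client, model = create_client("gemini-2.0-flash-exp")  # hits the "gemini" branch
content, msg_history = get_response_from_llm(
    msg="Summarize the experiment in one sentence.",
    client=client,
    model=model,
    system_message="You are a concise research assistant.",
)
print(content)
```
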
19 changes: 10 additions & 9 deletions ai_scientist/perform_icbinb_writeup.py
@@ -7,17 +7,16 @@
import subprocess
import traceback
import unicodedata
import uuid
import tempfile

from dotenv import load_dotenv
load_dotenv()
from ai_scientist.llm import (
get_response_from_llm,
extract_json_between_markers,
create_client,
AVAILABLE_LLMS,
)

from ai_scientist.utils.token_tracker import track_token_usage

from ai_scientist.tools.semantic_scholar import search_for_papers

@@ -733,6 +732,8 @@ def filter_experiment_summaries(exp_summaries, step_name):
elif stage_name == "ABLATION_SUMMARY" and step_name == "plot_aggregation":
filtered_summaries[stage_name] = {}
for ablation_summary in exp_summaries[stage_name]:
if ablation_summary is None:
continue
filtered_summaries[stage_name][ablation_summary["ablation_name"]] = {}
for node_key in ablation_summary.keys():
if node_key in node_keys_to_keep:
@@ -742,7 +743,7 @@
return filtered_summaries


def gather_citations(base_folder, num_cite_rounds=20, small_model="gpt-4o-2024-05-13"):
def gather_citations(base_folder, num_cite_rounds=20, small_model=os.getenv("LLM_SMALL_MODEL", "gpt-4o-2024-05-13")):
"""
Gather citations for a paper, with ability to resume from previous progress.

@@ -859,8 +860,8 @@ def perform_writeup(
citations_text=None,
no_writing=False,
num_cite_rounds=20,
small_model="gpt-4o-2024-05-13",
big_model="o1-2024-12-17",
small_model=os.getenv("LLM_SMALL_MODEL", "gpt-4o-2024-05-13"),
big_model=os.getenv("LLM_MODEL", "o1-2024-12-17"),
n_writeup_reflections=3,
page_limit=4,
):
@@ -950,7 +951,7 @@

# Generate VLM-based descriptions
try:
vlm_client, vlm_model = create_vlm_client("gpt-4o-2024-05-13")
vlm_client, vlm_model = create_vlm_client(os.getenv("LLM_MODEL", "gpt-4o-2024-05-13"))
desc_map = {}
for pf in plot_names:
ppath = osp.join(figures_dir, pf)
@@ -1207,7 +1208,7 @@ def perform_writeup(
base_folder, f"{osp.basename(base_folder)}_reflection_final_page_limit.pdf"
)
# Compile current version before reflection
print(f"[green]Compiling PDF for reflection final page limit...[/green]")
print("[green]Compiling PDF for reflection final page limit...[/green]")

print(f"reflection step {i+1}")

@@ -1232,7 +1233,7 @@

compile_latex(latex_folder, reflection_pdf)
else:
print(f"No changes in reflection page step.")
print("No changes in reflection page step.")

return osp.exists(reflection_pdf)

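
One note on the env-driven defaults introduced above: Python evaluates default argument values once, at function definition time, so `os.getenv(...)` in a signature captures the environment as it stands when the module is imported (after the top-level `load_dotenv()`), not at call time. A minimal illustration:

```python
# Default argument values are evaluated when the function is defined.
import os

os.environ["LLM_SMALL_MODEL"] = "model-a"

def pick_model(model=os.getenv("LLM_SMALL_MODEL", "fallback")):
    return model

os.environ["LLM_SMALL_MODEL"] = "model-b"
print(pick_model())  # prints "model-a": the default was captured at definition time
```
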
13 changes: 8 additions & 5 deletions ai_scientist/perform_writeup.py
@@ -15,9 +15,7 @@
create_client,
AVAILABLE_LLMS,
)

from ai_scientist.tools.semantic_scholar import search_for_papers

from ai_scientist.perform_vlm_review import generate_vlm_img_review
from ai_scientist.vlm import create_client as create_vlm_client

@@ -456,8 +454,8 @@ def perform_writeup(
base_folder,
no_writing=False,
num_cite_rounds=20,
small_model="gpt-4o-2024-05-13",
big_model="o1-2024-12-17",
small_model=os.getenv("LLM_SMALL_MODEL", "gpt-4o-2024-05-13"),
big_model=os.getenv("LLM_MODEL", "o1-2024-12-17"),
n_writeup_reflections=3,
page_limit=8,
):
@@ -525,6 +523,11 @@
for fplot in os.listdir(figures_dir):
if fplot.lower().endswith(".png"):
plot_names.append(fplot)
# copy plots to ./latex folder
for fplot in plot_names:
src_path = osp.join(figures_dir, fplot)
dest_path = osp.join(latex_folder, fplot)
shutil.copy(src_path, dest_path)

# Load aggregator script to include in the prompt
aggregator_path = osp.join(base_folder, "auto_plot_aggregator.py")
@@ -589,7 +592,7 @@

# Generate VLM-based descriptions but do not overwrite plot_names
try:
vlm_client, vlm_model = create_vlm_client("gpt-4o-2024-05-13")
vlm_client, vlm_model = create_vlm_client(os.getenv("LLM_SMALL_MODEL", "gpt-4o-2024-05-13"))
desc_map = {}
for pf in plot_names:
ppath = osp.join(figures_dir, pf)
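
The new copy step assumes `latex_folder` already exists and that `shutil` is imported at the top of the module. A self-contained sketch of the same idea with an existence guard (a hypothetical helper, not in the repo):

```python
# Sketch: copy every .png plot into the LaTeX folder so the build can find it.
import os
import os.path as osp
import shutil

def copy_plots(figures_dir: str, latex_folder: str) -> list[str]:
    os.makedirs(latex_folder, exist_ok=True)
    copied = []
    for fname in os.listdir(figures_dir):
        if fname.lower().endswith(".png"):
            shutil.copy(osp.join(figures_dir, fname), osp.join(latex_folder, fname))
            copied.append(fname)
    return copied
```
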
1 change: 1 addition & 0 deletions ai_scientist/treesearch/agent_manager.py
@@ -880,6 +880,7 @@ def _create_stage_analysis_prompt(
# Save stage transition analysis to notes directory
base_dir = Path(self.workspace_dir).parent.parent
run_name = Path(self.workspace_dir).name
stage_number = previous_stages[-1].stage_number
notes_dir = (
base_dir
/ "logs"
27 changes: 20 additions & 7 deletions ai_scientist/treesearch/backend/backend_openai.py
@@ -6,7 +6,10 @@
from funcy import notnone, once, select_values
import openai
from rich import print
import os
from dotenv import load_dotenv

load_dotenv()
logger = logging.getLogger("ai-scientist")

_client: openai.OpenAI = None # type: ignore
@@ -22,7 +25,17 @@
@once
def _setup_openai_client():
global _client
_client = openai.OpenAI(max_retries=0)
gemini_api_key = os.getenv("GEMINI_API_KEY")
if gemini_api_key is None:
_client = openai.OpenAI(
max_retries=0,
)
else:
_client = openai.OpenAI(
max_retries=0,
api_key=os.getenv("GEMINI_API_KEY"),
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)


def query(
Expand Down Expand Up @@ -55,12 +68,12 @@ def query(
if func_spec is None:
output = choice.message.content
else:
assert (
choice.message.tool_calls
), f"function_call is empty, it is not a function call: {choice.message}"
assert (
choice.message.tool_calls[0].function.name == func_spec.name
), "Function name mismatch"
assert choice.message.tool_calls, (
f"function_call is empty, it is not a function call: {choice.message}"
)
assert choice.message.tool_calls[0].function.name == func_spec.name, (
"Function name mismatch"
)
try:
print(f"[cyan]Raw func call response: {choice}[/cyan]")
output = json.loads(choice.message.tool_calls[0].function.arguments)
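
Because `_setup_openai_client` is wrapped in funcy's `@once`, the backend choice is frozen the first time a query runs: with `GEMINI_API_KEY` set, every later call through this backend uses the Gemini endpoint, regardless of the model name. A sketch of the `@once` semantics:

```python
# funcy.once runs a function at most once; later calls are no-ops.
from funcy import once

calls = []

@once
def setup():
    calls.append("configured")

setup()
setup()
print(calls)  # ["configured"]: the second call did nothing
```
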
5 changes: 1 addition & 4 deletions ai_scientist/treesearch/backend/utils.py
@@ -34,10 +34,7 @@ def opt_messages_to_list(
system_message: str | None, user_message: str | None
) -> list[dict[str, str]]:
messages = []
if system_message:
messages.append({"role": "system", "content": system_message})
if user_message:
messages.append({"role": "user", "content": user_message})
messages.append({"role": "user", "content": (system_message or "") + "\n" + (user_message or "")})
return messages
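
The rewritten helper folds the system prompt into a single user message, possibly to sidestep system-role handling on the Gemini endpoint (an assumption; the PR does not state a reason). Note that it now emits a message even when both inputs are None. Its observable behavior:

```python
# Behavior of the rewritten helper: system and user text merge into one user message.
def opt_messages_to_list(system_message, user_message):
    return [{"role": "user", "content": (system_message or "") + "\n" + (user_message or "")}]

print(opt_messages_to_list("You are helpful.", "Hi there."))
# [{'role': 'user', 'content': 'You are helpful.\nHi there.'}]
```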

