Skip to content
This repository was archived by the owner on Oct 17, 2024. It is now read-only.

[Feature] Add Gemini Client for generations #44

Merged
merged 7 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Lint-only CI: installs the formatter/linter requirements and runs `make check`
# on every push.
name: CI

on: [push]

env:
  OMP_NUM_THREADS: 2
  MKL_NUM_THREADS: 2
  PIP_DISABLE_PIP_VERSION_CHECK: 1

jobs:
  lint:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.11"]

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      # The key includes a hash of the requirements file so the cache is
      # refreshed when pinned versions change; restore-keys still lets a run
      # start from the most recent cache for this OS/Python pair.
      - uses: actions/cache@v4
        name: Cache pip packages
        with:
          path: ~/.cache/uv
          key: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ hashFiles('requirements-format.txt') }}
          restore-keys: |
            ${{ runner.os }}-python-${{ matrix.python-version }}-

      - name: Install uv
        run: curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Install dependencies
        run: uv pip install --system -r requirements-format.txt

      - name: Check lint
        run: make check
42 changes: 42 additions & 0 deletions evaluated/google/gemini_1.5_flash_001/1-shot.jsonl

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions evaluated/google/gemini_1.5_flash_001/cot-1-shot.jsonl

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions evaluated/google/gemini_1.5_flash_001/default.jsonl

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions evaluated/google/gemini_1.5_pro_001/1-shot.jsonl

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions evaluated/google/gemini_1.5_pro_001/cot-1-shot.jsonl

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions evaluated/google/gemini_1.5_pro_001/default.jsonl

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions generated/google/gemini_1.5_flash_001/1-shot.jsonl

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions generated/google/gemini_1.5_flash_001/cot-1-shot.jsonl

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions generated/google/gemini_1.5_flash_001/default.jsonl

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions generated/google/gemini_1.5_pro_001/1-shot.jsonl

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions generated/google/gemini_1.5_pro_001/cot-1-shot.jsonl

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions generated/google/gemini_1.5_pro_001/default.jsonl

Large diffs are not rendered by default.

93 changes: 93 additions & 0 deletions generator-gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import argparse # noqa: I001
import os

import google.generativeai as genai
import pandas as pd
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from tqdm import tqdm

from templates import PROMPT_STRATEGY

# TODO: generator-gemini.py to converge with generator.py
# Read the key from the GOOGLE_API_KEY environment variable so a real key never
# has to be written into (and committed with) this file. The "..." placeholder
# is kept as the fallback so the previous edit-in-place workflow still works.
API_KEY = os.environ.get("GOOGLE_API_KEY", "...")
MODEL_NAME = "gemini-1.5-pro-001"

genai.configure(api_key=API_KEY)
model = genai.GenerativeModel(MODEL_NAME)

# Passed with every generate_content call; each listed category is set to BLOCK_NONE.
safety_settings = {
    "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
    "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
    "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
    "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE",
}

parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output_dir", help="Directory to save outputs", default="./generated")
args = parser.parse_args()

# One JSON record per line; the loop below reads the "id", "category",
# "questions" and "references" columns.
df_questions = pd.read_json("questions.jsonl", orient="records", encoding="utf-8-sig", lines=True)

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(args.output_dir, exist_ok=True)


@retry(stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(Exception))
def call_gemini_api(input_text):
    """Send ``input_text`` to the Gemini model and return the generated text.

    The text of the first part of the first candidate is returned. A
    ``ValueError`` is raised when the response carries no candidates or the
    first candidate has no parts. Because the decorator retries on any
    ``Exception``, such failures are re-attempted: up to 10 attempts with a
    fixed 1-second wait between them.
    """
    result = model.generate_content([input_text], safety_settings=safety_settings)

    candidates = result.candidates
    if not candidates:
        raise ValueError("Invalid operation: No candidates returned in the response.")

    first_candidate = candidates[0]
    parts = first_candidate.content.parts
    if not parts:
        # Dump the empty candidate so the reason for the missing parts is visible.
        print(first_candidate)
        raise ValueError("Invalid operation: No parts found in the candidate.")

    return parts[0].text


# For every prompt strategy: answer each question's first turn, feed that answer
# back to answer the second turn, then write both answers to
# <output_dir>/<strategy_name>.jsonl.
for strategy_name, prompts in PROMPT_STRATEGY.items():
    # The two helpers close over this iteration's ``prompts`` and are only
    # called within the same iteration.

    def format_single_turn_question(question):
        """Render the strategy prompts plus the first user question as "role: content" lines."""
        messages = prompts + [{"role": "user", "content": question[0]}]
        return "\n".join(f"{message['role']}: {message['content']}" for message in messages)

    def format_double_turn_question(question, single_turn_output):
        """Render the prompts, first question, its answer and the follow-up question as "role: content" lines."""
        messages = prompts + [
            {"role": "user", "content": question[0]},
            {"role": "assistant", "content": single_turn_output},
            {"role": "user", "content": question[1]},
        ]
        return "\n".join(f"{message['role']}: {message['content']}" for message in messages)

    single_turn_questions = df_questions["questions"].map(format_single_turn_question)
    single_turn_outputs = []
    for formatted_text in tqdm(single_turn_questions, desc=f"Generating single-turn outputs for {strategy_name}"):
        single_turn_outputs.append(call_gemini_api(formatted_text))

    # Pair each question with its first-turn answer by row position.
    # single_turn_outputs was built in row order, so zip is always aligned;
    # indexing by ``id - 1`` was only correct when ids were exactly 1..N in
    # row order and otherwise mispaired answers or raised IndexError.
    multi_turn_questions = [
        format_double_turn_question(question, single_turn_output)
        for question, single_turn_output in zip(df_questions["questions"], single_turn_outputs)
    ]
    multi_turn_outputs = []
    for formatted_text in tqdm(multi_turn_questions, desc=f"Generating multi-turn outputs for {strategy_name}"):
        multi_turn_outputs.append(call_gemini_api(formatted_text))

    df_output = pd.DataFrame(
        {
            "id": df_questions["id"],
            "category": df_questions["category"],
            "questions": df_questions["questions"],
            # One (first-turn, second-turn) pair per question row.
            "outputs": list(zip(single_turn_outputs, multi_turn_outputs)),
            "references": df_questions["references"],
        }
    )
    output_path = os.path.join(args.output_dir, f"{strategy_name}.jsonl")
    df_output.to_json(output_path, orient="records", lines=True, force_ascii=False)
    print(f"Saved outputs to {output_path}")
1 change: 1 addition & 0 deletions requirements-format.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ruff==0.4.9