Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 49 additions & 60 deletions data_gen/core/llm_requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,26 @@
from typing import Optional


def upload_file_to_openai(file_path: str, api_key: str, purpose: str = "batch") -> str:
    """
    Upload a file to OpenAI and return the file ID.

    Args:
        file_path: Path to the file to upload (e.g., .jsonl)
        api_key: OpenAI API key
        purpose: Purpose of the file upload (default: "batch")

    Returns:
        The file ID assigned by OpenAI

    Raises:
        OSError: If the file cannot be opened for reading.
        requests.HTTPError: If the API answers with an error status.
        KeyError: If the JSON response carries no "id" field.
    """
    import os

    url = "https://api.openai.com/v1/files"
    headers = {"Authorization": f"Bearer {api_key}"}
    # Send only the base name as the multipart filename: the full local path
    # would leak directory structure to the remote service.
    upload_name = os.path.basename(file_path)
    with open(file_path, "rb") as f:
        files = {"file": (upload_name, f, "application/jsonl")}
        data = {"purpose": purpose}
        # The request must be issued while the file handle is still open.
        response = requests.post(url, headers=headers, files=files, data=data)
    response.raise_for_status()
    return response.json()["id"]


class LLMRequester:
"""Abstract base class for LLM API requests."""
def request(self, prompt: str) -> str:
Expand All @@ -24,36 +44,10 @@ def request_batch(self, input_file_path: str) -> str:
pass


class OpenAIRequester(LLMRequester):
    """LLMRequester implementation that sends one prompt per call to the OpenAI chat completions API."""

    def __init__(self, api_key: str, model: str = "gpt-3.5-turbo"):
        # Endpoint is fixed; credentials and model are supplied by the caller.
        self.api_url = "https://api.openai.com/v1/chat/completions"
        self.model = model
        self.api_key = api_key

    def request(self, prompt: str) -> str:
        """
        Post a single user message and return the text of the first choice.

        Args:
            prompt: Text sent as the content of one "user" message.

        Returns:
            The "content" of the first choice's message in the JSON response.
        """
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
        }
        auth_headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        reply = requests.post(self.api_url, headers=auth_headers, json=payload)
        # Surface HTTP error statuses as exceptions before touching the body.
        reply.raise_for_status()
        first_choice = reply.json()["choices"][0]
        return first_choice["message"]["content"]


class OpenAIBatchRequester(LLMRequester):
"""Concrete implementation of LLMRequester for OpenAI API batch requests."""

def __init__(self, api_key: str, model: str = "gpt-3.5-turbo"):
def __init__(self, api_key: str, model: str = "gpt-4.1-mini"):
self.api_key = api_key
self.model = model
self.files_url = "https://api.openai.com/v1/files"
Expand All @@ -62,14 +56,36 @@ def __init__(self, api_key: str, model: str = "gpt-3.5-turbo"):
"Authorization": f"Bearer {self.api_key}"
}

def request_batch(self, input_file_path: str, output_path: Optional[str] = None) -> str:
    """
    Run a batch job end to end: upload, create, poll, download.

    Args:
        input_file_path: Path to the file containing batch prompts (e.g., .jsonl)
        output_path: Where to save the output file (optional); when omitted,
            a name derived from the batch ID is used.

    Returns:
        Path to the output file, or the batch job ID if the polled status
        is anything other than "completed".

    Raises:
        RuntimeError: If the batch is completed but carries no output_file_id.
    """
    # Upload the input, then start a batch job on the uploaded file.
    batch_id = self._create_batch(self._upload_file(input_file_path))

    # Poll; hand back the batch ID so the caller can check on it later.
    batch_obj = self._poll_batch_status(batch_id)
    if batch_obj.get("status") != "completed":
        return batch_id

    output_file_id = batch_obj.get("output_file_id")
    if not output_file_id:
        raise RuntimeError("Batch completed but no output_file_id found.")

    destination = output_path if output_path is not None else f"batch_output_{batch_id}.jsonl"
    self._download_output_file(output_file_id, destination)
    return destination

def _upload_file(self, file_path: str) -> str:
    """
    Upload a .jsonl file to OpenAI and return the file ID.

    Delegates to the module-level ``upload_file_to_openai`` helper with
    purpose "batch" rather than repeating the upload request inline.

    Args:
        file_path: Path to the .jsonl file to upload.

    Returns:
        The file ID returned by the upload helper.
    """
    return upload_file_to_openai(file_path, self.api_key, purpose="batch")

def _create_batch(self, input_file_id: str) -> str:
"""Create a batch job and return the batch ID."""
Expand Down Expand Up @@ -109,30 +125,3 @@ def _download_output_file(self, output_file_id: str, output_path: str) -> None:
with open(output_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)

def request_batch(self, input_file_path: str, output_path: Optional[str] = None) -> str:
    """
    Submit a file of prompts as an OpenAI batch and fetch its output.

    Args:
        input_file_path: Path to the file containing batch prompts (e.g., .jsonl)
        output_path: Where to save the output file (optional)

    Returns:
        Path to the output file or batch job ID if not completed

    Raises:
        RuntimeError: If a completed batch has no output_file_id.
    """
    uploaded_id = self._upload_file(input_file_path)
    batch_id = self._create_batch(uploaded_id)
    batch_info = self._poll_batch_status(batch_id)

    finished = batch_info.get("status") == "completed"
    if not finished:
        # Not completed: return the batch ID for further status checking.
        return batch_id

    result_file_id = batch_info.get("output_file_id")
    if not result_file_id:
        raise RuntimeError("Batch completed but no output_file_id found.")

    # Fall back to a name derived from the batch ID when no path was given.
    if output_path is None:
        target_path = f"batch_output_{batch_id}.jsonl"
    else:
        target_path = output_path
    self._download_output_file(result_file_id, target_path)
    return target_path
20 changes: 6 additions & 14 deletions data_gen/core/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,6 @@
import pytest
from typing import List, Dict, Any
from ..generator import DataGenerator
from ..llm_requester import LLMRequester


class MockLLMRequester(LLMRequester):
    """Test double whose ``request`` echoes the prompt inside a fixed JSON string."""

    def request(self, prompt: str) -> str:
        # Canned JSON payload; the prompt is interpolated with %-formatting.
        template = '{"mock_key": "mock_value", "prompt": "%s"}'
        return template % prompt


class MockBatchLLMRequester(LLMRequester):
    """Test double whose ``request_batch`` always reports the same batch ID."""

    def request_batch(self, input_file_path: str) -> str:
        # The input path is ignored; a constant ID is returned.
        batch_id = "mock_batch_id"
        return batch_id


class TestDataGenerator(DataGenerator):
Expand All @@ -39,7 +27,11 @@ def parse_llm_response(self, response: str) -> Dict[str, Any]:

@pytest.fixture
def generator():
    """Provide a TestDataGenerator wired to a minimal in-test requester stub."""
    # Use a simple mock for LLMRequester: only ``request`` is implemented,
    # and it echoes the prompt inside a fixed JSON string.
    class DummyLLMRequester:
        def request(self, prompt: str) -> str:
            return '{"mock_key": "mock_value", "prompt": "%s"}' % prompt

    llm_requester = DummyLLMRequester()
    return TestDataGenerator(llm_requester)


Expand Down Expand Up @@ -74,4 +66,4 @@ def test_parse_llm_response(generator):
response = '{"mock_key": "mock_value", "prompt": "Prompt for test"}'
parsed = generator.parse_llm_response(response)
assert parsed["mock_key"] == "mock_value"
assert parsed["prompt"] == "Prompt for test"
assert parsed["prompt"] == "Prompt for test"
52 changes: 52 additions & 0 deletions data_gen/core/tests/test_llm_requester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Tests for LLMRequester and related utilities.
Run with:
uv run pytest data_gen/core/tests/test_llm_requester.py
"""

from data_gen.core.llm_requester import LLMRequester, upload_file_to_openai


class MockLLMRequester(LLMRequester):
    """Stub requester that returns a fixed JSON string embedding the prompt."""

    def request(self, prompt: str) -> str:
        # Simple JSON string for testing, built with %-interpolation.
        canned_json = '{"mock_key": "mock_value", "prompt": "%s"}'
        return canned_json % prompt


class MockBatchLLMRequester(LLMRequester):
    """Stub batch requester that returns a constant batch ID."""

    def request_batch(self, input_file_path: str) -> str:
        # No file is read; the same ID is returned for any input path.
        fixed_batch_id = "mock_batch_id"
        return fixed_batch_id


def test_upload_file_to_openai(monkeypatch):
    """
    Test upload_file_to_openai to ensure it sends the correct request and parses the file ID.
    Uses monkeypatch to mock requests.post and avoid real API calls.
    """
    import os
    import tempfile

    import requests

    class DummyResponse:
        def raise_for_status(self) -> None:
            pass

        def json(self) -> dict:
            return {"id": "file-1234"}

    def mock_post(url: str, headers: dict, files: dict, data: dict):
        assert url == "https://api.openai.com/v1/files"
        assert headers["Authorization"].startswith("Bearer ")
        assert data["purpose"] == "batch"
        assert "file" in files
        return DummyResponse()

    monkeypatch.setattr(requests, "post", mock_post)

    # Create the dummy file in a throwaway directory so the test never
    # overwrites or deletes a real "dummy.jsonl" in the working directory,
    # and so cleanup happens even if the upload call raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dummy_path = os.path.join(tmp_dir, "dummy.jsonl")
        with open(dummy_path, "w", encoding="utf-8") as f:
            f.write("{}\n")
        file_id = upload_file_to_openai(dummy_path, api_key="sk-test")

    assert file_id == "file-1234"
Loading
Loading