Skip to content

Commit be84f4d

Browse files
authored
Merge pull request #4 from small-thinking/oai-batch-request-e2e
Separate request and generator
2 parents 33278e4 + e62faa5 commit be84f4d

4 files changed

Lines changed: 225 additions & 74 deletions

File tree

data_gen/core/llm_requester.py

Lines changed: 49 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,26 @@
33
from typing import Optional
44

55

6+
def upload_file_to_openai(
    file_path: str,
    api_key: str,
    purpose: str = "batch",
    timeout: float = 60.0,
) -> str:
    """
    Upload a file to OpenAI and return the file ID.

    Args:
        file_path: Path to the file to upload (e.g., .jsonl)
        api_key: OpenAI API key
        purpose: Purpose of the file upload (default: "batch")
        timeout: Seconds to wait for the HTTP request; requests has no
            default timeout, so without this a stalled upload would hang
            the caller indefinitely.

    Returns:
        The file ID assigned by OpenAI.

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.
    """
    url = "https://api.openai.com/v1/files"
    headers = {"Authorization": f"Bearer {api_key}"}
    with open(file_path, "rb") as f:
        files = {"file": (file_path, f, "application/jsonl")}
        data = {"purpose": purpose}
        response = requests.post(
            url, headers=headers, files=files, data=data, timeout=timeout
        )
        response.raise_for_status()
        return response.json()["id"]
24+
25+
626
class LLMRequester:
727
"""Abstract base class for LLM API requests."""
828
def request(self, prompt: str) -> str:
@@ -24,36 +44,10 @@ def request_batch(self, input_file_path: str) -> str:
2444
pass
2545

2646

27-
class OpenAIRequester(LLMRequester):
28-
"""Concrete implementation of LLMRequester for OpenAI API single requests."""
29-
30-
def __init__(self, api_key: str, model: str = "gpt-3.5-turbo"):
31-
self.api_key = api_key
32-
self.model = model
33-
self.api_url = "https://api.openai.com/v1/chat/completions"
34-
35-
def request(self, prompt: str) -> str:
36-
"""Send a prompt to the OpenAI API and return the response as a string."""
37-
headers = {
38-
"Authorization": f"Bearer {self.api_key}",
39-
"Content-Type": "application/json"
40-
}
41-
data = {
42-
"model": self.model,
43-
"messages": [
44-
{"role": "user", "content": prompt}
45-
]
46-
}
47-
response = requests.post(self.api_url, headers=headers, json=data)
48-
response.raise_for_status()
49-
result = response.json()
50-
return result["choices"][0]["message"]["content"]
51-
52-
5347
class OpenAIBatchRequester(LLMRequester):
5448
"""Concrete implementation of LLMRequester for OpenAI API batch requests."""
5549

56-
def __init__(self, api_key: str, model: str = "gpt-3.5-turbo"):
50+
def __init__(self, api_key: str, model: str = "gpt-4.1-mini"):
5751
self.api_key = api_key
5852
self.model = model
5953
self.files_url = "https://api.openai.com/v1/files"
@@ -62,14 +56,36 @@ def __init__(self, api_key: str, model: str = "gpt-3.5-turbo"):
6256
"Authorization": f"Bearer {self.api_key}"
6357
}
6458

59+
def request_batch(self, input_file_path: str, output_path: Optional[str] = None) -> str:
    """
    Run a batch of prompts through the OpenAI API from a file.

    Uploads the prompt file, creates a batch job, polls until the job
    settles, and downloads the result when the job completed.

    Args:
        input_file_path: Path to the file containing batch prompts (e.g., .jsonl)
        output_path: Where to save the output file (optional)

    Returns:
        Path to the output file, or the batch job ID if the job has not
        completed so the caller can keep checking its status.
    """
    # Upload the prompt file, then kick off the batch job against it.
    uploaded_file_id = self._upload_file(input_file_path)
    job_id = self._create_batch(uploaded_file_id)

    # Wait on the job; hand back the ID if it did not finish.
    job = self._poll_batch_status(job_id)
    if job.get("status") != "completed":
        return job_id  # Return batch ID for further status checking

    result_file_id = job.get("output_file_id")
    if not result_file_id:
        raise RuntimeError("Batch completed but no output_file_id found.")

    destination = output_path if output_path is not None else f"batch_output_{job_id}.jsonl"
    self._download_output_file(result_file_id, destination)
    return destination
6586
def _upload_file(self, file_path: str) -> str:
    """Upload a .jsonl prompt file and return the OpenAI file ID."""
    # Delegate to the module-level helper; batch jobs always use purpose="batch".
    return upload_file_to_openai(file_path, self.api_key, purpose="batch")
7389

7490
def _create_batch(self, input_file_id: str) -> str:
7591
"""Create a batch job and return the batch ID."""
@@ -109,30 +125,3 @@ def _download_output_file(self, output_file_id: str, output_path: str) -> None:
109125
with open(output_path, "wb") as f:
110126
for chunk in response.iter_content(chunk_size=8192):
111127
f.write(chunk)
112-
113-
def request_batch(self, input_file_path: str, output_path: Optional[str] = None) -> str:
114-
"""
115-
Send a batch of prompts to the OpenAI API using a file.
116-
Args:
117-
input_file_path: Path to the file containing batch prompts (e.g., .jsonl)
118-
output_path: Where to save the output file (optional)
119-
Returns:
120-
Path to the output file or batch job ID if not completed
121-
"""
122-
# 1. Upload file
123-
input_file_id = self._upload_file(input_file_path)
124-
# 2. Create batch
125-
batch_id = self._create_batch(input_file_id)
126-
# 3. Poll for status
127-
batch_obj = self._poll_batch_status(batch_id)
128-
status = batch_obj.get("status")
129-
if status != "completed":
130-
return batch_id # Return batch ID for further status checking
131-
# 4. Download output file
132-
output_file_id = batch_obj.get("output_file_id")
133-
if not output_file_id:
134-
raise RuntimeError("Batch completed but no output_file_id found.")
135-
if output_path is None:
136-
output_path = f"batch_output_{batch_id}.jsonl"
137-
self._download_output_file(output_file_id, output_path)
138-
return output_path

data_gen/core/tests/test_generator.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,6 @@
77
import pytest
88
from typing import List, Dict, Any
99
from ..generator import DataGenerator
10-
from ..llm_requester import LLMRequester
11-
12-
13-
class MockLLMRequester(LLMRequester):
14-
def request(self, prompt: str) -> str:
15-
# Return a simple JSON string for testing
16-
return '{"mock_key": "mock_value", "prompt": "%s"}' % prompt
17-
18-
19-
class MockBatchLLMRequester(LLMRequester):
20-
def request_batch(self, input_file_path: str) -> str:
21-
return "mock_batch_id"
2210

2311

2412
class TestDataGenerator(DataGenerator):
@@ -39,7 +27,11 @@ def parse_llm_response(self, response: str) -> Dict[str, Any]:
3927

4028
@pytest.fixture
def generator():
    """Build a TestDataGenerator backed by a canned-response LLM stub."""

    class StubRequester:
        # Minimal stand-in for LLMRequester: echoes the prompt in a JSON payload.
        def request(self, prompt: str) -> str:
            return '{"mock_key": "mock_value", "prompt": "%s"}' % prompt

    return TestDataGenerator(StubRequester())
4436

4537

@@ -74,4 +66,4 @@ def test_parse_llm_response(generator):
7466
response = '{"mock_key": "mock_value", "prompt": "Prompt for test"}'
7567
parsed = generator.parse_llm_response(response)
7668
assert parsed["mock_key"] == "mock_value"
77-
assert parsed["prompt"] == "Prompt for test"
69+
assert parsed["prompt"] == "Prompt for test"
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""
2+
Tests for LLMRequester and related utilities.
3+
Run with:
4+
uv run pytest data_gen/core/tests/test_llm_requester.py
5+
"""
6+
7+
from data_gen.core.llm_requester import LLMRequester, upload_file_to_openai
8+
9+
10+
class MockLLMRequester(LLMRequester):
    """LLMRequester stub that echoes the prompt inside a fixed JSON string."""

    def request(self, prompt: str) -> str:
        # Canned JSON payload for testing; embeds the prompt verbatim.
        canned = '{"mock_key": "mock_value", "prompt": "%s"}' % prompt
        return canned
14+
15+
16+
class MockBatchLLMRequester(LLMRequester):
    """LLMRequester stub whose batch call always yields a fixed batch ID."""

    def request_batch(self, input_file_path: str) -> str:
        # The input path is ignored; tests only need a stable ID back.
        return "mock_batch_id"
19+
20+
21+
def test_upload_file_to_openai(monkeypatch):
    """
    Test upload_file_to_openai to ensure it sends the correct request and parses the file ID.
    Uses monkeypatch to mock requests.post and avoid real API calls.
    """
    import os
    import tempfile

    import requests

    class DummyResponse:
        # Minimal stand-in for requests.Response: succeeds and returns a file ID.
        def raise_for_status(self) -> None:
            pass

        def json(self) -> dict:
            return {"id": "file-1234"}

    def mock_post(url: str, headers: dict, files: dict, data: dict, **kwargs):
        # Verify the helper hits the files endpoint with auth and batch purpose.
        # **kwargs tolerates extra keyword arguments (e.g. a timeout).
        assert url == "https://api.openai.com/v1/files"
        assert headers["Authorization"].startswith("Bearer ")
        assert data["purpose"] == "batch"
        assert "file" in files
        return DummyResponse()

    monkeypatch.setattr(requests, "post", mock_post)

    # Write the dummy .jsonl into a temporary directory so the test never
    # pollutes the working directory and cleanup happens even on failure.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dummy_path = os.path.join(tmp_dir, "dummy.jsonl")
        with open(dummy_path, "w") as f:
            f.write("{}\n")
        file_id = upload_file_to_openai(dummy_path, api_key="sk-test")
        assert file_id == "file-1234"

0 commit comments

Comments
 (0)