-
Notifications
You must be signed in to change notification settings - Fork 5
[DRAFT] Dev/OpenAI #263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[DRAFT] Dev/OpenAI #263
Changes from 2 commits
154b8c6
bfcaac0
5c16452
a181403
f77d942
4f52aab
a147b9f
37cbfc1
04c2caa
8eabfa1
9f41886
658c2b4
7943017
3bab0fd
e10d688
94aba33
cec7903
2260e25
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| from typing import Union, Optional, Literal | ||
| from functools import partial | ||
| from copy import deepcopy | ||
| from openai import OpenAI | ||
|
|
||
|
|
||
| from align_system.algorithms.abstracts import StructuredInferenceEngine | ||
|
|
||
|
|
||
| class VLLMInferenceEngine(StructuredInferenceEngine): | ||
|
|
||
| # TODO Either create a second class VLLMInferenceEngine or two different params classes | ||
| class OpenAIInferenceEngine(StructuredInferenceEngine): | ||
def __init__(self,
             model_name: str,
             temperature: float,
             top_p: float,
             max_tokens: int,
             inference_batch_size: int,
             base_url: Optional[str] = None,
             api_key: Optional[str] = None,
             organization: Optional[str] = None,
             project: Optional[str] = None,
             webhook_secret: Optional[str] = None,
             _strict_response_validation: bool = False,
             timeout: Union[float, None, Literal["NOT_GIVEN"]] = "NOT_GIVEN"
             ):
    """Configure the OpenAI client and the default request arguments.

    Args:
        model_name: Model identifier sent as ``model`` on every request.
        temperature: Sampling temperature sent on every request.
        top_p: Nucleus-sampling value sent on every request.
        max_tokens: Sent as ``max_output_tokens`` on every request.
        inference_batch_size: Stored on the instance; not used by the
            client set-up itself.
        base_url: Endpoint of an OpenAI-compatible server. ``None`` means
            the client's default endpoint.
        api_key: Forwarded to the client.
        organization: Forwarded to the client.
        project: Forwarded to the client.
        webhook_secret: Forwarded to the client.
        _strict_response_validation: Forwarded to the client.
        timeout: Request timeout. The string ``"NOT_GIVEN"`` is the
            config-file placeholder for "use the client's default".
    """
    self.model_name = model_name
    self.temperature = temperature
    self.top_p = top_p
    self.max_tokens = max_tokens
    self.inference_batch_size = inference_batch_size
    # Kept so other methods can tell a custom endpoint from the default one.
    self.base_url = base_url

    # The OpenAI client refuses to start without an API key (argument or
    # OPENAI_API_KEY), even when the target server never checks it. For a
    # custom endpoint with no key available, fall back to a placeholder.
    import os
    if api_key is None and base_url and not os.environ.get("OPENAI_API_KEY"):
        api_key = "EMPTY"

    client_kwargs = {
        "api_key": api_key,
        "organization": organization,
        "project": project,
        "webhook_secret": webhook_secret,
        "base_url": base_url,
        "_strict_response_validation": _strict_response_validation,
    }
    # "NOT_GIVEN" is only a placeholder string; omit the argument entirely so
    # the client applies its own default instead of receiving a str timeout.
    if timeout != "NOT_GIVEN":
        client_kwargs["timeout"] = timeout

    self.client = OpenAI(**client_kwargs)

    # NOTE(review): a reviewer suggested these might make more sense as
    # __init__ arguments; they are left hard-coded for this first version.
    # NOTE(review): "reasoning", "service_tier" and "store" look specific to
    # the hosted OpenAI service - confirm they are accepted by other servers
    # and by the configured model.
    self.responses_kwargs = {
        "model": self.model_name,
        "reasoning": {"effort": "medium", "summary": "auto"},
        "max_output_tokens": self.max_tokens,
        "temperature": self.temperature,
        "top_p": self.top_p,
        "service_tier": "flex",
        "store": True,
    }
ygefen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
def dialog_to_prompt(self, dialog: list[dict]) -> list[dict]:
    """Return a copy of ``dialog`` with roles adapted to the target endpoint.

    Args:
        dialog: Chat messages, each a dict with at least a ``"role"`` key.

    Returns:
        A deep copy of ``dialog``. When no custom ``base_url`` is configured
        (i.e. the OpenAI service is the target), every ``"system"`` role is
        renamed to ``"developer"``; otherwise the roles are left as they are.
        The input list is never modified.
    """
    # OpenAI uses a "developer" prompt instead of "system" (which is not exposed to the caller)
    # https://platform.openai.com/docs/guides/prompt-engineering#message-roles-and-instruction-following
    prompt = deepcopy(dialog)

    # A missing attribute is treated like "no custom endpoint" so this method
    # does not depend on how the instance was initialised.
    if getattr(self, "base_url", None):
        # Custom OpenAI-compatible endpoint: keep the standard role names.
        return prompt

    # We are targeting the OpenAI service.
    for message in prompt:
        if message["role"] == "system":
            message["role"] = "developer"
    return prompt
ygefen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
def run_inference(self, prompts: Union[str, list[str]], schema: str) -> Union[dict, list[dict]]:
    """Send ``prompts`` to the Responses API, constrained to ``schema``.

    Args:
        prompts: Passed unchanged as the request ``input``.
        schema: Passed to ``text_field`` to build the structured-output
            ``text`` argument.

    Returns:
        Whatever ``self.client.responses.create`` returns.
    """
    # NOTE(review): the value returned here is the client's response object,
    # not the dict/list of dicts the annotation promises - confirm whether
    # callers expect the parsed JSON instead.
    # NOTE(review): ``schema`` is annotated as str but is forwarded as the
    # JSON-schema payload - confirm the expected type.
    return self.client.responses.create(
        input=prompts,
        text=self.text_field(schema),
        **self.responses_kwargs,
    )
|
|
||
@staticmethod
def text_field(schema):
    """Build the ``text`` argument requesting strict JSON-schema output.

    Args:
        schema: JSON schema the response must conform to; embedded as-is.

    Returns:
        A dict of the form ``{"format": {...}}`` with type ``json_schema``,
        a fixed format name, the given schema, and ``strict`` enabled.
    """
    return {
        "format": {
            "type": "json_schema",
            # The format name may only contain letters, digits, underscores
            # and dashes, so it must not contain a space.
            "name": "ITM_Schema",
            "schema": schema,
            "strict": True,
        }
    }
|
|
||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| _target_: align_system.algorithms.openai_inference_engine.OpenAIInferenceEngine | ||
|
|
||
| # https://github.com/openai/openai-python/blob/main/src/openai/_client.py#L99 | ||
| model_name: gpt-4 | ||
| temperature: 0.7 | ||
| top_p: 0.9 | ||
| max_tokens: 4096 | ||
| inference_batch_size: 5 | ||
|
|
||
| # Optional kwargs | ||
| base_url: null # Set to null for OpenAI cloud | ||
| api_key: null # Reads from OPENAI_API_KEY env var if null | ||
| organization: null # Reads from OPENAI_ORG_ID env var if null | ||
| project: null # Reads from OPENAI_PROJECT_ID env var if null | ||
| webhook_secret: null # Reads from OPENAI_WEBHOOK_SECRET env var if null | ||
| timeout: "NOT_GIVEN" | ||
| # Optional unstable kwargs | ||
ygefen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| _strict_response_validation: True | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So you were able to successfully run against a local vLLM instance with this config (and the OpenAI API inference code)? |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| _target_: align_system.algorithms.openai_inference_engine.OpenAIInferenceEngine | ||
|
|
||
| model_name: meta-llama/Llama-3.1-8B-Instruct | ||
| base_url: http://localhost:8000/v1 | ||
| api_key: null # vLLM doesn't validate API keys, but the OpenAI client itself still requires a non-empty key (passed here or via OPENAI_API_KEY) | ||
| temperature: 0.3 | ||
| top_p: 0.9 | ||
| max_tokens: 4096 | ||
| inference_batch_size: 5 |
Uh oh!
There was an error while loading. Please reload this page.