Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions align_system/algorithms/openai_inference_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from typing import Union, Optional, Literal
from functools import partial
from copy import deepcopy
from openai import OpenAI


from align_system.algorithms.abstracts import StructuredInferenceEngine


# TODO(review): Either create a second VLLMInferenceEngine class or two
# different params classes.  For now this single class targets both the
# OpenAI cloud service and OpenAI-compatible servers (e.g. vLLM) via
# `base_url`; a `VLLMInferenceEngine` alias is provided below for
# backward compatibility with configs that reference that name.
class OpenAIInferenceEngine(StructuredInferenceEngine):
    """Structured inference engine backed by the OpenAI Responses API.

    Works against the OpenAI cloud service (``base_url is None``) or any
    OpenAI-compatible endpoint such as a local vLLM server (``base_url``
    set, e.g. ``http://localhost:8000/v1``).
    """

    def __init__(self,
                 model_name: str,
                 temperature: float,
                 top_p: float,
                 max_tokens: int,
                 inference_batch_size: int,
                 base_url: Optional[str] = None,
                 api_key: Optional[str] = None,
                 organization: Optional[str] = None,
                 project: Optional[str] = None,
                 webhook_secret: Optional[str] = None,
                 _strict_response_validation: bool = False,
                 timeout: Union[float, None, Literal["NOT_GIVEN"]] = "NOT_GIVEN"
                 ):
        """Configure the client and the kwargs shared by every request.

        :param model_name: model identifier passed to the Responses API.
        :param temperature: sampling temperature.
        :param top_p: nucleus-sampling probability mass.
        :param max_tokens: cap on generated output tokens.
        :param inference_batch_size: number of prompts submitted per batch.
        :param base_url: OpenAI-compatible endpoint; None means the OpenAI
            cloud service.
        :param api_key: API key; when None the openai client falls back to
            the OPENAI_API_KEY environment variable.
        :param organization: optional OpenAI organization id.
        :param project: optional OpenAI project id.
        :param webhook_secret: optional webhook signing secret.
        :param _strict_response_validation: openai-client internal flag;
            enables strict validation of API responses.
        :param timeout: request timeout in seconds; the string "NOT_GIVEN"
            (the YAML-friendly stand-in for the openai NOT_GIVEN sentinel)
            leaves the client's default timeout in effect.
        """
        self.model_name = model_name
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.inference_batch_size = inference_batch_size
        # Kept so dialog_to_prompt can tell the OpenAI service apart from
        # OpenAI-compatible servers (e.g. vLLM).
        self.base_url = base_url

        client_kwargs = {
            "api_key": api_key,
            "organization": organization,
            "project": project,
            "webhook_secret": webhook_secret,
            "base_url": base_url,
            "_strict_response_validation": _strict_response_validation,
        }
        # "NOT_GIVEN" is a config-file stand-in for the openai NOT_GIVEN
        # sentinel; omit the kwarg entirely so the client default applies
        # instead of passing the literal string as a timeout value.
        if timeout != "NOT_GIVEN":
            client_kwargs["timeout"] = timeout

        self.client = OpenAI(**client_kwargs)

        # Kwargs shared by every responses.create() call.
        # NOTE(review): "reasoning" effort and the "flex" service tier are
        # OpenAI-service features; they may need adjusting when targeting a
        # vLLM endpoint — confirm against the server in use.
        self.responses_kwargs = {
            "model": self.model_name,
            "reasoning": {"effort": "medium", "summary": "auto"},
            "max_output_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "service_tier": "flex",
            "store": True
        }

    def dialog_to_prompt(self, dialog: list[dict]) -> list[dict]:
        """Adapt a dialog's message roles for the target endpoint.

        OpenAI uses a "developer" prompt instead of "system" (which is not
        exposed to the caller):
        https://platform.openai.com/docs/guides/prompt-engineering#message-roles-and-instruction-following

        :param dialog: list of message dicts, each with a "role" key.
        :return: a copy with "system" roles renamed to "developer" when
            targeting the OpenAI service; the dialog unchanged otherwise.
        """
        if not self.base_url:
            # We are targeting the OpenAI service
            prompt = deepcopy(dialog)
            for p in prompt:
                if p["role"] == "system":
                    p["role"] = "developer"
            return prompt
        # OpenAI-compatible servers (e.g. vLLM) accept the "system" role
        # as-is; previously this path fell through and returned None.
        return dialog

    def run_inference(self, prompts: Union[str, list[str]], schema: str) -> Union[dict, list[dict]]:
        """Submit prompts to the Responses API with a strict JSON schema.

        :param prompts: a single prompt or a batch of prompts.
        :param schema: JSON schema the structured output must satisfy.
        :return: the Responses API result.
        """
        return self.client.responses.create(
            input=prompts,
            text=self.text_field(schema),
            **self.responses_kwargs
        )

    @staticmethod
    def text_field(schema):
        """Wrap *schema* in the Responses API's structured-output `text` field.

        :param schema: JSON schema dict for the expected output.
        :return: the `text` parameter dict enforcing strict schema adherence.
        """
        return {
            "format": {
                "type": "json_schema",
                "name": "ITM Schema",
                "schema": schema,
                "strict": True
            }
        }


# Backward-compatible alias for the stray class name that previously
# appeared in this module (see TODO above).
VLLMInferenceEngine = OpenAIInferenceEngine




18 changes: 18 additions & 0 deletions align_system/configs/inference_engine/openai_gpt4.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Inference-engine config targeting the OpenAI cloud service with gpt-4.
_target_: align_system.algorithms.openai_inference_engine.OpenAIInferenceEngine

# Client constructor parameters documented at:
# https://github.com/openai/openai-python/blob/main/src/openai/_client.py#L99
model_name: gpt-4
temperature: 0.7
top_p: 0.9
max_tokens: 4096
inference_batch_size: 5

# Optional kwargs
base_url: null # Set to null for OpenAI cloud
api_key: null # Reads from OPENAI_API_KEY env var if null
organization: null # Reads from OPENAI_ORG_ID env var if null
project: null # Reads from OPENAI_PROJECT_ID env var if null
webhook_secret: null # Reads from OPENAI_WEBHOOK_SECRET env var if null
timeout: "NOT_GIVEN" # String stand-in for the openai NOT_GIVEN sentinel (client default timeout)
# Optional unstable kwargs
_strict_response_validation: True
9 changes: 9 additions & 0 deletions align_system/configs/inference_engine/vllm_openai_llama3.yaml
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you were able to successfully run against a local vLLM instance with this config (and the openai api inference code?)

Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Inference-engine config targeting a local vLLM server (OpenAI-compatible
# API) serving Llama 3.1 8B Instruct.
_target_: align_system.algorithms.openai_inference_engine.OpenAIInferenceEngine

model_name: meta-llama/Llama-3.1-8B-Instruct
base_url: http://localhost:8000/v1 # local vLLM OpenAI-compatible endpoint
api_key: null # vLLM doesn't validate API keys
temperature: 0.3
top_p: 0.9
max_tokens: 4096
inference_batch_size: 5
Loading