lib/instructor.ex (69 additions, 43 deletions)
@@ -260,6 +260,61 @@ defmodule Instructor do
changeset
end

+  @spec prepare_prompt(Keyword.t(), any()) :: map()
+  def prepare_prompt(params, config \\ nil) do
+    response_model = Keyword.fetch!(params, :response_model)
+    mode = Keyword.get(params, :mode, :tools)
+    params = params_for_mode(mode, response_model, params)
+
+    adapter(config).prompt(params)
+  end
+
+  @spec consume_response(any(), Keyword.t()) ::
+          {:ok, map()} | {:error, String.t()} | {:error, Ecto.Changeset.t(), Keyword.t()}
+  def consume_response(response, params) do
+    validation_context = Keyword.get(params, :validation_context, %{})
+    response_model = Keyword.fetch!(params, :response_model)
+    mode = Keyword.get(params, :mode, :tools)
+
+    model =
+      if is_ecto_schema(response_model) do
+        response_model.__struct__()
+      else
+        {%{}, response_model}
+      end
+
+    with {:valid_json, {:ok, params}} <- {:valid_json, parse_response_for_mode(mode, response)},
+         changeset <- cast_all(model, params),
+         {:validation, %Ecto.Changeset{valid?: true} = changeset, _response} <-
+           {:validation, call_validate(response_model, changeset, validation_context), response} do
+      {:ok, changeset |> Ecto.Changeset.apply_changes()}
+    else
+      {:valid_json, {:error, error}} ->
+        {:error, "Invalid JSON returned from LLM: #{inspect(error)}"}
+
+      {:validation, changeset, response} ->
+        errors = Instructor.ErrorFormatter.format_errors(changeset)
+
+        params =
+          Keyword.update(params, :messages, [], fn messages ->
+            messages ++
+              echo_response(response) ++
+              [
+                %{
+                  role: "system",
+                  content: """
+                  The response did not pass validation. Please try again and fix the following validation errors:\n
+
+                  #{errors}
+                  """
+                }
+              ]
+          end)
+
+        {:error, changeset, params}
+    end
+  end

defp do_streaming_partial_array_chat_completion(response_model, params, config) do
wrapped_model = %{
value:
@@ -270,7 +325,7 @@ defmodule Instructor do
params = Keyword.put(params, :response_model, wrapped_model)
validation_context = Keyword.get(params, :validation_context, %{})
mode = Keyword.get(params, :mode, :tools)
-    params = params_for_mode(mode, wrapped_model, params)
+    prompt = prepare_prompt(params, config)

model =
if is_ecto_schema(response_model) do
Expand All @@ -279,7 +334,7 @@ defmodule Instructor do
{%{}, response_model}
end

-    adapter(config).chat_completion(params, config)
+    adapter(config).chat_completion(prompt, params, config)
|> Stream.map(&parse_stream_chunk_for_mode(mode, &1))
|> Instructor.JSONStreamParser.parse()
|> Stream.transform(
@@ -341,9 +396,9 @@ defmodule Instructor do
params = Keyword.put(params, :response_model, wrapped_model)
validation_context = Keyword.get(params, :validation_context, %{})
mode = Keyword.get(params, :mode, :tools)
-    params = params_for_mode(mode, wrapped_model, params)
+    prompt = prepare_prompt(params, config)

-    adapter(config).chat_completion(params, config)
+    adapter(config).chat_completion(prompt, params, config)
|> Stream.map(&parse_stream_chunk_for_mode(mode, &1))
|> Instructor.JSONStreamParser.parse()
|> Stream.transform(
@@ -389,9 +444,10 @@ defmodule Instructor do
params = Keyword.put(params, :response_model, wrapped_model)
validation_context = Keyword.get(params, :validation_context, %{})
mode = Keyword.get(params, :mode, :tools)
-    params = params_for_mode(mode, wrapped_model, params)

-    adapter(config).chat_completion(params, config)
+    prompt = prepare_prompt(params, config)
+
+    adapter(config).chat_completion(prompt, params, config)
|> Stream.map(&parse_stream_chunk_for_mode(mode, &1))
|> Jaxon.Stream.from_enumerable()
|> Jaxon.Stream.query([:root, "value", :all])
@@ -416,56 +472,26 @@
end

defp do_chat_completion(response_model, params, config) do
-    validation_context = Keyword.get(params, :validation_context, %{})
max_retries = Keyword.get(params, :max_retries)
-    mode = Keyword.get(params, :mode, :tools)
-    params = params_for_mode(mode, response_model, params)
+    prompt = prepare_prompt(params, config)

-    model =
-      if is_ecto_schema(response_model) do
-        response_model.__struct__()
-      else
-        {%{}, response_model}
-      end

-    with {:llm, {:ok, response}} <- {:llm, adapter(config).chat_completion(params, config)},
-         {:valid_json, {:ok, params}} <- {:valid_json, parse_response_for_mode(mode, response)},
-         changeset <- cast_all(model, params),
-         {:validation, %Ecto.Changeset{valid?: true} = changeset, _response} <-
-           {:validation, call_validate(response_model, changeset, validation_context), response} do
-      {:ok, changeset |> Ecto.Changeset.apply_changes()}
+    with {:llm, {:ok, response}} <-
+           {:llm, adapter(config).chat_completion(prompt, params, config)},
+         {:ok, result} <- consume_response(response, params) do
+      {:ok, result}
else
{:llm, {:error, error}} ->
{:error, "LLM Adapter Error: #{inspect(error)}"}

-      {:valid_json, {:error, error}} ->
-        {:error, "Invalid JSON returned from LLM: #{inspect(error)}"}
-
-      {:validation, changeset, response} ->
+      {:error, changeset, new_params} ->
if max_retries > 0 do
errors = Instructor.ErrorFormatter.format_errors(changeset)

Logger.debug("Retrying LLM call for #{inspect(response_model)}:\n\n #{inspect(errors)}",
errors: errors
)

-          params =
-            params
-            |> Keyword.put(:max_retries, max_retries - 1)
-            |> Keyword.update(:messages, [], fn messages ->
-              messages ++
-                echo_response(response) ++
-                [
-                  %{
-                    role: "system",
-                    content: """
-                    The response did not pass validation. Please try again and fix the following validation errors:\n
-
-                    #{errors}
-                    """
-                  }
-                ]
-            end)
+          params = Keyword.put(new_params, :max_retries, max_retries - 1)

do_chat_completion(response_model, params, config)
else
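Taken together, these changes split the old single-shot flow into three phases: prepare the prompt, execute it against the adapter, and consume the response. A minimal sketch of driving the phases by hand, assuming the OpenAI adapter and a hypothetical Demo.UserInfo schema (the schema and the message are illustrative, not part of this diff):

defmodule Demo do
  # Hypothetical response model, for illustration only.
  defmodule UserInfo do
    use Ecto.Schema

    @primary_key false
    embedded_schema do
      field(:name, :string)
      field(:age, :integer)
    end
  end

  def run do
    params = [
      response_model: UserInfo,
      mode: :tools,
      messages: [%{role: "user", content: "John Doe is 30 years old"}]
    ]

    # Phase 1: build the adapter-specific request payload.
    prompt = Instructor.prepare_prompt(params)

    # Phase 2: execute the request however suits you (inline, queued, batched).
    {:ok, response} = Instructor.Adapters.OpenAI.chat_completion(prompt, params, nil)

    # Phase 3: cast, validate, and apply the raw response.
    case Instructor.consume_response(response, params) do
      {:ok, user_info} ->
        user_info

      # On validation failure, params come back with the errors already
      # appended to :messages, ready for a retry.
      {:error, _changeset, retry_params} ->
        {:retry, retry_params}

      {:error, reason} ->
        {:error, reason}
    end
  end
end

This is also what the reworked do_chat_completion/3 does internally: on {:error, changeset, new_params} it decrements :max_retries and re-drives the loop with the updated params.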
lib/instructor/adapter.ex (2 additions, 1 deletion)
@@ -2,5 +2,6 @@ defmodule Instructor.Adapter do
@moduledoc """
Behavior for `Instructor.Adapter`.
"""
-  @callback chat_completion([Keyword.t()], any()) :: any()
+  @callback chat_completion(map(), [Keyword.t()], any()) :: any()
+  @callback prompt(Keyword.t()) :: map()
end
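With prompt/1 added to the behaviour, a custom adapter now implements two callbacks: one to turn instructor params into a request payload, and one to execute a previously prepared payload. A rough sketch of a stub adapter under those assumptions (the module name and return shapes are hypothetical):

defmodule MyApp.Adapters.Stub do
  @behaviour Instructor.Adapter

  # Build the request payload: drop the instructor-only keys,
  # keep everything else as the request body.
  @impl true
  def prompt(params) do
    params
    |> Keyword.drop([:response_model, :validation_context, :max_retries, :mode])
    |> Enum.into(%{})
  end

  # Execute a prepared payload; `params` is still passed alongside for
  # transport-level options such as :stream.
  @impl true
  def chat_completion(prompt, params, _config) do
    if Keyword.get(params, :stream, false) do
      {:error, :streaming_not_supported}
    else
      {:ok, %{"echoed_request" => prompt}}
    end
  end
end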
lib/instructor/adapters/llamacpp.ex (20 additions, 18 deletions)
@@ -31,34 +31,39 @@ defmodule Instructor.Adapters.Llamacpp do
...> )
"""
@impl true
-  def chat_completion(params, _config \\ nil) do
+  def chat_completion(prompt, params, _config \\ nil) do
+    stream = Keyword.get(params, :stream, false)
+
+    if stream do
+      do_streaming_chat_completion(prompt)
+    else
+      do_chat_completion(prompt)
+    end
+  end
+
+  @impl true
+  def prompt(params) do
{response_model, _} = Keyword.pop!(params, :response_model)
{messages, _} = Keyword.pop!(params, :messages)

json_schema = JSONSchema.from_ecto_schema(response_model)
grammar = GBNF.from_json_schema(json_schema)
prompt = apply_chat_template(chat_template(), messages)
-    stream = Keyword.get(params, :stream, false)
-
-    if stream do
-      do_streaming_chat_completion(prompt, grammar)
-    else
-      do_chat_completion(prompt, grammar)
-    end
+    %{
+      grammar: grammar,
+      prompt: prompt
+    }
end

-  defp do_streaming_chat_completion(prompt, grammar) do
+  defp do_streaming_chat_completion(prompt) do
pid = self()

Stream.resource(
fn ->
Task.async(fn ->
Req.post(url(),
-          json: %{
-            grammar: grammar,
-            prompt: prompt,
-            stream: true
-          },
+          json: Map.put(prompt, :stream, true),
receive_timeout: 60_000,
into: fn {:data, data}, {req, resp} ->
send(pid, data)
@@ -94,13 +99,10 @@ defmodule Instructor.Adapters.Llamacpp do
}
end

-  defp do_chat_completion(prompt, grammar) do
+  defp do_chat_completion(prompt) do
response =
Req.post(url(),
-        json: %{
-          grammar: grammar,
-          prompt: prompt
-        },
+        json: prompt,
receive_timeout: 60_000
)

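For llama.cpp, the prepared prompt is a plain map carrying the GBNF grammar and the templated prompt string, which is why the streaming path can simply merge stream: true into it with Map.put/3. Roughly (values illustrative, grammar body elided):

# Illustrative shape of Instructor.Adapters.Llamacpp.prompt/1 output:
%{
  # GBNF grammar derived from the response model's JSON schema
  grammar: "root ::= ...",
  # messages rendered through the chat template
  prompt: "<|im_start|>user\nJohn Doe is 30 years old<|im_end|>\n"
}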
lib/instructor/adapters/openai.ex (17 additions, 12 deletions)
@@ -5,25 +5,30 @@ defmodule Instructor.Adapters.OpenAI do
@behaviour Instructor.Adapter

@impl true
-  def chat_completion(params, config) do
+  def chat_completion(prompt, params, config) do
config = if config, do: config, else: config()

+    stream = Keyword.get(params, :stream, false)
+
+    if stream do
+      do_streaming_chat_completion(prompt, config)
+    else
+      do_chat_completion(prompt, config)
+    end
+  end
+
+  @impl true
+  def prompt(params) do
# Peel off instructor only parameters
{_, params} = Keyword.pop(params, :response_model)
{_, params} = Keyword.pop(params, :validation_context)
{_, params} = Keyword.pop(params, :max_retries)
{_, params} = Keyword.pop(params, :mode)
-    stream = Keyword.get(params, :stream, false)
-    params = Enum.into(params, %{})
-
-    if stream do
-      do_streaming_chat_completion(params, config)
-    else
-      do_chat_completion(params, config)
-    end
+    Enum.into(params, %{})
end

-  defp do_streaming_chat_completion(params, config) do
+  defp do_streaming_chat_completion(prompt, config) do
pid = self()
options = http_options(config)

@@ -32,7 +37,7 @@
Task.async(fn ->
options =
Keyword.merge(options,
-          json: params,
+          json: prompt,
auth: {:bearer, api_key(config)},
into: fn {:data, data}, {req, resp} ->
chunks =
@@ -75,8 +80,8 @@ defmodule Instructor.Adapters.OpenAI do
)
end

-  defp do_chat_completion(params, config) do
-    options = Keyword.merge(http_options(config), json: params, auth: {:bearer, api_key(config)})
+  defp do_chat_completion(prompt, config) do
+    options = Keyword.merge(http_options(config), json: prompt, auth: {:bearer, api_key(config)})

case Req.post(url(config), options) do
{:ok, %{status: 200, body: body}} -> {:ok, body}
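Since prompt/1 now only peels off the instructor-specific keys and converts the rest to a map, whatever remains in params is sent to the API verbatim. A sketch of the transformation (response model name hypothetical):

[
  response_model: UserInfo,    # dropped by prompt/1
  validation_context: %{},     # dropped
  max_retries: 3,              # dropped
  mode: :tools,                # dropped
  model: "gpt-3.5-turbo",
  messages: [%{role: "user", content: "John Doe is 30 years old"}]
]
|> Instructor.Adapters.OpenAI.prompt()
# => %{
#      model: "gpt-3.5-turbo",
#      messages: [%{role: "user", content: "John Doe is 30 years old"}]
#    }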