Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bfcl patch #821

Draft
wants to merge 3 commits into
base: canary
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine/baml-lib/jinja/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ pub enum ChatMessagePart {
Audio(BamlMedia),
}

#[derive(Debug, PartialEq, Clone)]
#[derive(Debug, PartialEq, Clone, Serialize)]
pub enum RenderedPrompt {
Completion(String),
Chat(Vec<RenderedChatMessage>),
Expand Down
8 changes: 4 additions & 4 deletions engine/baml-runtime/src/internal/llm_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub struct RetryLLMResponse {
pub failed: Vec<LLMResponse>,
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Serialize)]
pub enum LLMResponse {
Success(LLMCompleteResponse),
LLMFailure(LLMErrorResponse),
Expand Down Expand Up @@ -70,7 +70,7 @@ impl LLMResponse {
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Serialize)]
pub struct LLMErrorResponse {
pub client: String,
pub model: Option<String>,
Expand All @@ -84,7 +84,7 @@ pub struct LLMErrorResponse {
pub code: ErrorCode,
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Serialize)]
pub enum ErrorCode {
InvalidAuthentication, // 401
NotSupported, // 403
Expand Down Expand Up @@ -135,7 +135,7 @@ impl ErrorCode {
}
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, Serialize)]
pub struct LLMCompleteResponse {
pub client: String,
pub model: String,
Expand Down
71 changes: 33 additions & 38 deletions engine/language-client-codegen/src/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,45 @@ use internal_baml_core::{
use self::python_language_features::{PythonLanguageFeatures, ToPython};
use crate::dir_writer::FileCollector;

#[derive(askama::Template)]
#[template(path = "async_client.py.j2", escape = "none")]
struct AsyncPythonClient {
struct PythonClient {
funcs: Vec<PythonFunction>,
}

#[derive(askama::Template)]
#[template(path = "sync_client.py.j2", escape = "none")]
struct SyncPythonClient {
funcs: Vec<PythonFunction>,
}
macro_rules! impl_from_python_client {
($($target:ident => $template:expr),+) => {
$(
#[derive(askama::Template)]
#[template(path = $template, escape = "none")]
struct $target {
funcs: Vec<PythonFunction>,
}

struct PythonClient {
funcs: Vec<PythonFunction>,
}
impl From<PythonClient> for $target {
fn from(client: PythonClient) -> Self {
Self {
funcs: client.funcs,
}
}
}

impl From<PythonClient> for AsyncPythonClient {
fn from(value: PythonClient) -> Self {
Self { funcs: value.funcs }
}
}
impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for $target {
type Error = anyhow::Error;

impl From<PythonClient> for SyncPythonClient {
fn from(value: PythonClient) -> Self {
Self { funcs: value.funcs }
}
fn try_from(params: (&'_ IntermediateRepr, &'_ crate::GeneratorArgs)) -> Result<Self> {
let python_client = PythonClient::try_from(params)?;
Ok(python_client.into())
}
}
)+
};
}

impl_from_python_client!(
AsyncPythonClient => "async_client.py.j2",
SyncPythonClient => "sync_client.py.j2",
UnstableAsyncPythonClient => "unstable_async_client.py.j2"
);

struct PythonFunction {
name: String,
partial_return_type: String,
Expand Down Expand Up @@ -80,6 +91,8 @@ pub(crate) fn generate(
collector.add_template::<generate_types::PythonTypes>("types.py", (ir, generator))?;
collector.add_template::<generate_types::TypeBuilder>("type_builder.py", (ir, generator))?;
collector.add_template::<AsyncPythonClient>("async_client.py", (ir, generator))?;
collector
.add_template::<UnstableAsyncPythonClient>("unstable_async_client.py", (ir, generator))?;
collector.add_template::<SyncPythonClient>("sync_client.py", (ir, generator))?;
collector.add_template::<PythonGlobals>("globals.py", (ir, generator))?;
collector.add_template::<PythonTracing>("tracing.py", (ir, generator))?;
Expand Down Expand Up @@ -125,24 +138,6 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for InlinedBaml {
}
}

impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for AsyncPythonClient {
type Error = anyhow::Error;

fn try_from(params: (&'_ IntermediateRepr, &'_ crate::GeneratorArgs)) -> Result<Self> {
let python_client = PythonClient::try_from(params)?;
Ok(python_client.into())
}
}

impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for SyncPythonClient {
type Error = anyhow::Error;

fn try_from(params: (&'_ IntermediateRepr, &'_ crate::GeneratorArgs)) -> Result<Self> {
let python_client = PythonClient::try_from(params)?;
Ok(python_client.into())
}
}

impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonClient {
type Error = anyhow::Error;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ class BamlAsyncClient:
self.__ctx_manager = ctx_manager
self.__stream_client = BamlStreamClient(self.__runtime, self.__ctx_manager)

def z_unstable_runtime(self) -> baml_py.BamlRuntime:
return self.__runtime

def z_unstable_ctx_manager(self) -> baml_py.BamlCtxManager:
return self.__ctx_manager

@property
def stream(self):
return self.__stream_client
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Any, Dict, List, Optional, TypeVar, Union, TypedDict, Type
from typing_extensions import NotRequired
import pprint

import baml_py
from pydantic import BaseModel, ValidationError, create_model

from . import partial_types, types
from .type_builder import TypeBuilder
from .globals import DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME


OutputType = TypeVar('OutputType')

def coerce(cls: Type[BaseModel], parsed: Any) -> Any:
try:
return cls.model_validate({"inner": parsed}).inner # type: ignore
except ValidationError as e:
raise TypeError(
"Internal BAML error while casting output to {}\n{}".format(
cls.__name__,
pprint.pformat(parsed)
)
) from e

# Define the TypedDict with optional parameters having default values
class BamlCallOptions(TypedDict, total=False):
tb: NotRequired[TypeBuilder]
client_registry: NotRequired[baml_py.baml_py.ClientRegistry]

class UnstableBamlAsyncClient:
__runtime: baml_py.BamlRuntime
__ctx_manager: baml_py.BamlCtxManager

def __init__(self, runtime: baml_py.BamlRuntime, ctx_manager: baml_py.BamlCtxManager):
self.__runtime = runtime
self.__ctx_manager = ctx_manager

@property
def stream(self):
return self.__stream_client


{% for fn in funcs %}
async def {{ fn.name }}(
self,
{% for (name, type) in fn.args -%}
{{name}}: {{type}},
{%- endfor %}
baml_options: BamlCallOptions = {},
) -> (baml_py.baml_py.FunctionResult, Union[Tuple[Literal[True, {{fn.return_type}}]], Tuple[Literal[False, None]]]):
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb
else:
tb = None
__cr__ = baml_options.get("client_registry", None)

raw = await self.__runtime.call_function(
"{{fn.name}}",
{
{% for (name, _) in fn.args -%}
"{{name}}": {{name}},
{%- endfor %}
},
self.__ctx_manager.get(),
tb,
__cr__,
)
if raw.is_ok() is None:
mdl = create_model("{{ fn.name }}ReturnType", inner=({{ fn.return_type }}, ...))
return raw, (True, coerce(mdl, raw.parsed()))
return raw, (False, None)
Comment on lines +70 to +73
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logic: The condition if raw.is_ok() is None seems incorrect. It should likely be if raw.is_ok() to check if the result is successful.

{% endfor %}

b = UnstableBamlAsyncClient(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)

__all__ = ["b"]
3 changes: 2 additions & 1 deletion engine/language_client_python/python_src/baml_py/baml_py.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple

class FunctionResult:
"""The result of a BAML function call.
Expand All @@ -18,6 +18,7 @@ class FunctionResult:
def parsed(self) -> Any: ...
# Returns True if the function call was successful, False otherwise
def is_ok(self) -> bool: ...
def internals(self) -> str: ...

class FunctionResultStream:
"""The result of a BAML function stream.
Expand Down
2 changes: 1 addition & 1 deletion engine/language_client_python/src/types/client_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl ClientRegistry {
options: PyObject,
retry_policy: Option<String>,
) -> PyResult<()> {
let Some(args) = parse_py_type(options.into_bound(py).to_object(py), false)? else {
let Some(args) = parse_py_type(options, false)? else {
return Err(BamlError::new_err(
"Failed to parse args, perhaps you used a non-serializable type?",
));
Expand Down
7 changes: 7 additions & 0 deletions engine/language_client_python/src/types/function_results.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use baml_runtime::internal::llm_client::LLMResponse;
use baml_types::BamlValue;
use pyo3::prelude::{pymethods, PyResult};
use pyo3::{PyObject, Python};
use pythonize::pythonize;
use serde_json::json;

crate::lang_wrapper!(FunctionResult, baml_runtime::FunctionResult);

Expand All @@ -23,4 +25,9 @@ impl FunctionResult {

Ok(pythonize(py, &BamlValue::from(parsed))?)
}

fn internals(&self) -> PyResult<String> {
let content = self.inner.llm_response().clone();
serde_json::to_string(&content).map_err(|e| crate::BamlError::new_err(e.to_string()))
}
}
1 change: 1 addition & 0 deletions engine/language_client_python/src/types/type_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use pyo3::{
types::{PyTuple, PyTupleMethods},
Bound, PyResult,
};
use pyo3::{PyObject, Python, ToPyObject};

crate::lang_wrapper!(TypeBuilder, type_builder::TypeBuilder);
crate::lang_wrapper!(EnumBuilder, type_builder::EnumBuilder, sync_thread_safe, name: String);
Expand Down
6 changes: 6 additions & 0 deletions integ-tests/python/baml_client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ def __init__(self, runtime: baml_py.BamlRuntime, ctx_manager: baml_py.BamlCtxMan
self.__ctx_manager = ctx_manager
self.__stream_client = BamlStreamClient(self.__runtime, self.__ctx_manager)

def z_unstable_runtime(self) -> baml_py.BamlRuntime:
return self.__runtime

def z_unstable_ctx_manager(self) -> baml_py.BamlCtxManager:
return self.__ctx_manager

@property
def stream(self):
return self.__stream_client
Expand Down
Loading