Skip to content

Commit abf39b8

Browse files
committed
Use the JsonSchema conversion methods in models
1 parent 1fb6bf1 commit abf39b8

File tree

6 files changed

+89
-123
lines changed

6 files changed

+89
-123
lines changed

outlines/models/dottxt.py

Lines changed: 3 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,9 @@
11
"""Integration with Dottxt's API."""
22

3-
import json
4-
from typing import TYPE_CHECKING, Any, Optional
5-
6-
from pydantic import TypeAdapter
3+
from typing import TYPE_CHECKING, Any, Optional, cast
74

85
from outlines.models.base import Model, ModelTypeAdapter
96
from outlines.types import CFG, JsonSchema, Regex
10-
from outlines.types.utils import (
11-
is_dataclass,
12-
is_genson_schema_builder,
13-
is_pydantic_model,
14-
is_typed_dict,
15-
)
167

178
if TYPE_CHECKING:
189
from dottxt import Dottxt as DottxtClient
@@ -77,20 +68,8 @@ def format_output_type(self, output_type: Optional[Any] = None) -> str:
7768
"CFG-based structured outputs will soon be available with "
7869
"Dottxt. Use an open source model in the meantime."
7970
)
80-
81-
elif isinstance(output_type, JsonSchema):
82-
return output_type.schema
83-
elif is_dataclass(output_type):
84-
schema = TypeAdapter(output_type).json_schema()
85-
return json.dumps(schema)
86-
elif is_typed_dict(output_type):
87-
schema = TypeAdapter(output_type).json_schema()
88-
return json.dumps(schema)
89-
elif is_pydantic_model(output_type):
90-
schema = output_type.model_json_schema()
91-
return json.dumps(schema)
92-
elif is_genson_schema_builder(output_type):
93-
return output_type.to_json()
71+
elif JsonSchema.is_json_schema(output_type):
72+
return cast(str, JsonSchema.convert_to(output_type, ["str"]))
9473
else:
9574
type_name = getattr(output_type, "__name__", output_type)
9675
raise TypeError(

outlines/models/gemini.py

Lines changed: 17 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -14,14 +14,11 @@
1414
from outlines.models.base import Model, ModelTypeAdapter
1515
from outlines.types import CFG, Choice, JsonSchema, Regex
1616
from outlines.types.utils import (
17-
is_dataclass,
1817
is_enum,
1918
get_enum_from_choice,
2019
get_enum_from_literal,
2120
is_genson_schema_builder,
2221
is_literal,
23-
is_pydantic_model,
24-
is_typed_dict,
2522
is_typing_list,
2623
)
2724

@@ -171,28 +168,18 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict:
171168
"CFG-based structured outputs are not available with Gemini. "
172169
"Use an open source model or dottxt instead."
173170
)
174-
elif is_genson_schema_builder(output_type):
175-
raise TypeError(
176-
"The Gemini SDK does not accept Genson schema builders as an "
177-
"input. Pass a Pydantic model, typed dict or dataclass "
178-
"instead."
179-
)
180-
elif isinstance(output_type, JsonSchema):
181-
raise TypeError(
182-
"The Gemini SDK does not accept Json Schemas as an input. "
183-
"Pass a Pydantic model, typed dict or dataclass instead."
184-
)
185171

186172
if output_type is None:
187173
return {}
188174

189-
# Structured types
190-
elif is_dataclass(output_type):
191-
return self.format_json_output_type(output_type)
192-
elif is_typed_dict(output_type):
193-
return self.format_json_output_type(output_type)
194-
elif is_pydantic_model(output_type):
195-
return self.format_json_output_type(output_type)
175+
# JSON schema types
176+
elif JsonSchema.is_json_schema(output_type):
177+
return self.format_json_output_type(
178+
JsonSchema.convert_to(
179+
output_type,
180+
["dataclass", "typeddict", "pydantic"]
181+
)
182+
)
196183

197184
# List of structured types
198185
elif is_typing_list(output_type):
@@ -233,21 +220,20 @@ def format_list_output_type(self, output_type: Optional[Any]) -> dict:
233220
if len(args) == 1:
234221
item_type = args[0]
235222

236-
# Check if list item type is supported
237-
if (
238-
is_pydantic_model(item_type)
239-
or is_typed_dict(item_type)
240-
or is_dataclass(item_type)
241-
):
223+
if JsonSchema.is_json_schema(item_type):
242224
return {
243225
"response_mime_type": "application/json",
244-
"response_schema": output_type,
226+
"response_schema": list[ # type: ignore
227+
JsonSchema.convert_to(
228+
item_type,
229+
["dataclass", "typeddict", "pydantic"]
230+
)
231+
],
245232
}
246-
247233
else:
248234
raise TypeError(
249-
"The only supported types for list items are Pydantic "
250-
+ "models, typed dicts and dataclasses."
235+
"The list items output type must contain a JSON schema "
236+
"type."
251237
)
252238

253239
raise TypeError(

outlines/models/ollama.py

Lines changed: 15 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,19 @@
11
"""Integration with the `ollama` library."""
22

3-
import json
43
from functools import singledispatchmethod
5-
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional, Union
6-
7-
from pydantic import TypeAdapter
4+
from typing import (
5+
TYPE_CHECKING,
6+
Any,
7+
AsyncIterator,
8+
Iterator,
9+
Optional,
10+
Union,
11+
cast,
12+
)
813

914
from outlines.inputs import Chat, Image
1015
from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
1116
from outlines.types import CFG, JsonSchema, Regex
12-
from outlines.types.utils import (
13-
is_dataclass,
14-
is_genson_schema_builder,
15-
is_pydantic_model,
16-
is_typed_dict,
17-
)
1817

1918
if TYPE_CHECKING:
2019
from ollama import Client
@@ -109,7 +108,7 @@ def _create_message(self, role: str, content: str | list) -> dict:
109108

110109
def format_output_type(
111110
self, output_type: Optional[Any] = None
112-
) -> Optional[str]:
111+
) -> Optional[dict]:
113112
"""Format the output type to pass to the client.
114113
115114
TODO: `int`, `float` and other Python types could be supported via
@@ -126,7 +125,9 @@ def format_output_type(
126125
The formatted output type to be passed to the model.
127126
128127
"""
129-
if isinstance(output_type, Regex):
128+
if output_type is None:
129+
return None
130+
elif isinstance(output_type, Regex):
130131
raise TypeError(
131132
"Regex-based structured outputs are not supported by Ollama. "
132133
"Use an open source model in the meantime."
@@ -136,22 +137,8 @@ def format_output_type(
136137
"CFG-based structured outputs are not supported by Ollama. "
137138
"Use an open source model in the meantime."
138139
)
139-
140-
if output_type is None:
141-
return None
142-
elif isinstance(output_type, JsonSchema):
143-
return json.loads(output_type.schema)
144-
elif is_dataclass(output_type):
145-
schema = TypeAdapter(output_type).json_schema()
146-
return schema
147-
elif is_typed_dict(output_type):
148-
schema = TypeAdapter(output_type).json_schema()
149-
return schema
150-
elif is_pydantic_model(output_type):
151-
schema = output_type.model_json_schema()
152-
return schema
153-
elif is_genson_schema_builder(output_type):
154-
return output_type.to_json()
140+
elif JsonSchema.is_json_schema(output_type):
141+
return cast(dict, JsonSchema.convert_to(output_type, ["dict"]))
155142
else:
156143
type_name = getattr(output_type, "__name__", output_type)
157144
raise TypeError(

outlines/models/openai.py

Lines changed: 7 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,29 +1,23 @@
11
"""Integration with OpenAI's API."""
22

3-
import json
43
from typing import (
54
TYPE_CHECKING,
65
Any,
76
AsyncIterator,
87
Iterator,
98
Optional,
109
Union,
10+
cast,
1111
)
1212
from functools import singledispatchmethod
1313

14-
from pydantic import BaseModel, TypeAdapter
14+
from pydantic import BaseModel
1515

1616
from outlines.inputs import Chat, Image
1717
from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
1818
from outlines.models.utils import set_additional_properties_false_json_schema
1919
from outlines.types import JsonSchema, Regex, CFG
20-
from outlines.types.utils import (
21-
is_dataclass,
22-
is_typed_dict,
23-
is_pydantic_model,
24-
is_genson_schema_builder,
25-
is_native_dict
26-
)
20+
from outlines.types.utils import is_native_dict
2721

2822
if TYPE_CHECKING:
2923
from openai import (
@@ -176,20 +170,10 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict:
176170
return {}
177171
elif is_native_dict(output_type):
178172
return self.format_json_mode_type()
179-
elif is_dataclass(output_type):
180-
output_type = TypeAdapter(output_type).json_schema()
181-
return self.format_json_output_type(output_type)
182-
elif is_typed_dict(output_type):
183-
output_type = TypeAdapter(output_type).json_schema()
184-
return self.format_json_output_type(output_type)
185-
elif is_pydantic_model(output_type):
186-
output_type = output_type.model_json_schema()
187-
return self.format_json_output_type(output_type)
188-
elif is_genson_schema_builder(output_type):
189-
schema = json.loads(output_type.to_json())
190-
return self.format_json_output_type(schema)
191-
elif isinstance(output_type, JsonSchema):
192-
return self.format_json_output_type(json.loads(output_type.schema))
173+
elif JsonSchema.is_json_schema(output_type):
174+
return self.format_json_output_type(
175+
cast(dict, JsonSchema.convert_to(output_type, ["dict"]))
176+
)
193177
else:
194178
type_name = getattr(output_type, "__name__", output_type)
195179
raise TypeError(

tests/models/test_gemini_type_adapter.py

Lines changed: 44 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -3,16 +3,16 @@
33
import sys
44
from dataclasses import dataclass
55
from enum import Enum, EnumMeta
6-
from typing import Literal
6+
from typing import Literal, get_args
77

88
from PIL import Image as PILImage
99
from genson import SchemaBuilder
10-
from google.genai import types
1110
from pydantic import BaseModel
1211

1312
from outlines import cfg, json_schema, regex
1413
from outlines.inputs import Chat, Image
1514
from outlines.models.gemini import GeminiTypeAdapter
15+
from outlines.types.utils import is_dataclass
1616

1717
if sys.version_info >= (3, 12):
1818
from typing import TypedDict
@@ -135,19 +135,29 @@ def test_gemini_type_adapter_output_invalid(adapter):
135135
with pytest.raises(TypeError, match="CFG-based structured outputs"):
136136
adapter.format_output_type(cfg(""))
137137

138-
with pytest.raises(TypeError, match="The Gemini SDK does not accept"):
139-
adapter.format_output_type(SchemaBuilder())
140-
141-
with pytest.raises(TypeError, match="The Gemini SDK does not"):
142-
adapter.format_output_type(json_schema(""))
143-
144138

145139
def test_gemini_type_adapter_output_none(adapter):
146140
result = adapter.format_output_type(None)
147141
assert result == {}
148142

149143

150-
def test_gemini_type_adapter_output_dataclass(adapter, schema):
144+
def test_gemini_type_adapter_output_json_schema(adapter, schema):
145+
result = adapter.format_output_type(json_schema(schema))
146+
assert isinstance(result, dict)
147+
assert result["response_mime_type"] == "application/json"
148+
assert is_dataclass(result["response_schema"])
149+
150+
151+
def test_gemini_type_adapter_output_list_json_schema(adapter, schema):
152+
result = adapter.format_output_type(list[json_schema(schema)])
153+
assert isinstance(result, dict)
154+
assert result["response_mime_type"] == "application/json"
155+
args = get_args(result["response_schema"])
156+
assert len(args) == 1
157+
assert is_dataclass(args[0])
158+
159+
160+
def test_gemini_type_adapter_output_dataclass(adapter):
151161
@dataclass
152162
class User:
153163
user_id: int
@@ -160,7 +170,7 @@ class User:
160170
}
161171

162172

163-
def test_gemini_type_adapter_output_list_dataclass(adapter, schema):
173+
def test_gemini_type_adapter_output_list_dataclass(adapter):
164174
class User(BaseModel):
165175
user_id: int
166176
name: str
@@ -172,7 +182,7 @@ class User(BaseModel):
172182
}
173183

174184

175-
def test_gemini_type_adapter_output_typed_dict(adapter, schema):
185+
def test_gemini_type_adapter_output_typed_dict(adapter):
176186
class User(TypedDict):
177187
user_id: int
178188
name: str
@@ -184,7 +194,7 @@ class User(TypedDict):
184194
}
185195

186196

187-
def test_gemini_type_adapter_output_list_typed_dict(adapter, schema):
197+
def test_gemini_type_adapter_output_list_typed_dict(adapter):
188198
class User(BaseModel):
189199
user_id: int
190200
name: str
@@ -196,7 +206,7 @@ class User(BaseModel):
196206
}
197207

198208

199-
def test_gemini_type_adapter_output_pydantic(adapter, schema):
209+
def test_gemini_type_adapter_output_pydantic(adapter):
200210
class User(BaseModel):
201211
user_id: int
202212
name: str
@@ -208,7 +218,7 @@ class User(BaseModel):
208218
}
209219

210220

211-
def test_gemini_type_adapter_output_list_pydantic(adapter, schema):
221+
def test_gemini_type_adapter_output_list_pydantic(adapter):
212222
class User(BaseModel):
213223
user_id: int
214224
name: str
@@ -220,6 +230,26 @@ class User(BaseModel):
220230
}
221231

222232

233+
def test_gemini_type_adapter_output_genson_schema_builder(adapter):
234+
builder = SchemaBuilder()
235+
builder.add_schema({"type": "object", "properties": {"foo": {"type": "string"}, "bar": {"type": "integer"}}, "required": ["foo"]})
236+
result = adapter.format_output_type(builder)
237+
assert isinstance(result, dict)
238+
assert result["response_mime_type"] == "application/json"
239+
assert is_dataclass(result["response_schema"])
240+
241+
242+
def test_gemini_type_adapter_output_list_genson_schema_builder(adapter):
243+
builder = SchemaBuilder()
244+
builder.add_schema({"type": "object", "properties": {"foo": {"type": "string"}, "bar": {"type": "integer"}}, "required": ["foo"]})
245+
result = adapter.format_output_type(list[builder])
246+
assert isinstance(result, dict)
247+
assert result["response_mime_type"] == "application/json"
248+
args = get_args(result["response_schema"])
249+
assert len(args) == 1
250+
assert is_dataclass(args[0])
251+
252+
223253
def test_gemini_type_adapter_output_enum(adapter):
224254
class Foo(Enum):
225255
Bar = "bar"

0 commit comments

Comments
 (0)