-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcore.py
149 lines (107 loc) · 4.15 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import contextlib
import dataclasses
import enum
import types
import typing
import annotated_types
import pydantic
import typing_extensions
class MessageRole(str, enum.Enum):
    """Author role of a chat message; serializes as its plain string value."""

    # The str mixin makes members compare equal to, and dump as, raw strings.
    system = "system"
    user = "user"
    assistant = "assistant"
@pydantic.dataclasses.dataclass
class TextContentItem:
    """A plain-text piece of message content."""

    text: str
@pydantic.dataclasses.dataclass
class ImageContentItem:
    """An image piece of message content, referenced by URL."""

    image_url: str
    """
    HTTP image url or data url in following format:
    data:image/jpeg;base64,{base64.b64encode(jpeg_image_bytes).decode('utf-8')}
    """
# Union of every supported structured content piece.
AnyContentItem = TextContentItem | ImageContentItem
# Non-empty list of content pieces; MinLen(1) is enforced by pydantic validation.
ContentItemList = typing.Annotated[list[AnyContentItem], annotated_types.MinLen(1)]
@pydantic.dataclasses.dataclass(kw_only=True)
class Message:
    """A single chat message: an author role plus plain-string or structured content."""

    role: MessageRole
    content: str | ContentItemList
@pydantic.dataclasses.dataclass
class LLMResponse:
    """Model output: the generated content and, if the provider returns it, separate reasoning text."""

    content: str | None = None
    reasoning_content: str | None = None
# Role-specific message constructors. Static type checkers see them as Message
# subclasses with a pinned, non-init `role` field (so the role is tracked in
# types); at runtime they are plain factory functions that build a Message.
# Both views accept the same single `content` argument.
if typing.TYPE_CHECKING:

    @pydantic.dataclasses.dataclass
    class SystemMessage(Message):
        role: typing.Literal[MessageRole.system] = pydantic.Field(MessageRole.system, init=False)
        content: str | ContentItemList

    @pydantic.dataclasses.dataclass
    class UserMessage(Message):
        role: typing.Literal[MessageRole.user] = pydantic.Field(MessageRole.user, init=False)
        content: str | ContentItemList

    @pydantic.dataclasses.dataclass
    class AssistantMessage(Message):
        role: typing.Literal[MessageRole.assistant] = pydantic.Field(MessageRole.assistant, init=False)
        content: str | ContentItemList
else:

    def SystemMessage(content: str | ContentItemList) -> Message:  # noqa: N802
        """Build a Message with role=system."""
        return Message(role=MessageRole.system, content=content)

    def UserMessage(content: str | ContentItemList) -> Message:  # noqa: N802
        """Build a Message with role=user."""
        return Message(role=MessageRole.user, content=content)

    def AssistantMessage(content: str | ContentItemList) -> Message:  # noqa: N802
        """Build a Message with role=assistant."""
        return Message(role=MessageRole.assistant, content=content)
class LLMConfig(pydantic.BaseModel):
    """Common configuration shared by LLM clients.

    Attributes:
        api_type: Identifier of the backend API implementation.
        temperature: Default sampling temperature for requests.
        request_extra: Extra payload merged into every request.
    """

    # Allow field names starting with "model_" (pydantic protects that prefix by default).
    model_config = pydantic.ConfigDict(protected_namespaces=())
    api_type: str
    temperature: float = 0.2
    request_extra: dict[str, typing.Any] = pydantic.Field(default_factory=dict)

    def _resolve_request_temperature(self, temperature_arg_value: float) -> float:
        """Return the effective temperature for a request.

        A plain float is used as-is. An `LLMConfigValue` sentinel (used as a
        parameter default in client signatures) resolves to the config
        attribute it names.
        """
        if isinstance(temperature_arg_value, LLMConfigValue):  # type: ignore[arg-type]
            # Fix: honor the sentinel's `attr` instead of hard-coding
            # `self.temperature`. LLMConfigValue(attr="temperature") resolves
            # identically to before, but other attrs now work too.
            return getattr(self, temperature_arg_value.attr)  # type: ignore[attr-defined]
        return temperature_arg_value
if typing.TYPE_CHECKING:
def LLMConfigValue(*, attr: str) -> typing.Any: # noqa: ANN401, N802
"""Defaults to value from LLMConfig."""
else:
@dataclasses.dataclass(kw_only=True, frozen=True, slots=True)
class LLMConfigValue:
"""Defaults to value from LLMConfig."""
attr: str
@dataclasses.dataclass(slots=True, init=False)
class LLMClient(typing.Protocol):
    """Structural (duck-typed) interface implemented by concrete LLM API clients.

    Implementations are async context managers: enter to acquire resources,
    exit to release them.
    """

    async def request_llm_message(
        self,
        messages: str | list[Message],
        *,
        # Sentinel default: resolved against the client's LLMConfig at request time.
        temperature: float = LLMConfigValue(attr="temperature"),
        extra: dict[str, typing.Any] | None = None,
    ) -> LLMResponse: ...  # raises LLMError, LLMRequestValidationError

    # NOTE(review): decorating a Protocol stub with asynccontextmanager looks
    # intended to make implementations usable as `async with client.stream_...()`;
    # confirm against a concrete implementation before changing.
    @contextlib.asynccontextmanager
    def stream_llm_message_chunks(
        self,
        messages: str | list[Message],
        *,
        temperature: float = LLMConfigValue(attr="temperature"),
        extra: dict[str, typing.Any] | None = None,
    ) -> typing.AsyncIterator[typing.AsyncIterable[LLMResponse]]: ...  # raises LLMError, LLMRequestValidationError

    async def __aenter__(self) -> typing_extensions.Self: ...

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None: ...
@dataclasses.dataclass
class AnyLLMClientError(Exception):
    """Base exception for all LLM client errors."""

    def __str__(self) -> str:
        # Reuse the dataclass-generated repr but drop the leading class name,
        # leaving just the "(field=value, ...)" part as the message.
        return repr(self).removeprefix(type(self).__name__)
@dataclasses.dataclass
class LLMError(AnyLLMClientError):
    """Error returned by the LLM backend; carries the raw response body."""

    response_content: bytes
# Raised when the backend rejects or truncates a request for exceeding
# token/length limits (a specialization of LLMError).
@dataclasses.dataclass
class OutOfTokensOrSymbolsError(LLMError): ...
@dataclasses.dataclass
class LLMRequestValidationError(AnyLLMClientError):
    """Raised before sending when the request itself is invalid."""

    message: str