Skip to content

Commit

Permalink
Фикс функций и распределение кода по папкам (#233)
Browse files Browse the repository at this point in the history
* - Перенес парсеры в папку output_parsers
- Перенес преобразователи функций в utils/function_calling.py
- Изменил парсеры, добавив поддержку pydantic v2 методов
- Пофиксил преобразование тулов, сделав кастомный генератор JSON-схемы `pydantic_generator.py`; также добавил проверку на тулы с Union[X, Y] -> теперь при встрече такого типа кидается exception, так как гигачат такое не поддерживает (убрал anyOf у Optional-полей и добавил их в required)
- Добавил кастомные классы GigaBaseTool и метод giga_tool, позволяющий добавлять return_schema и few_shot_examples
- Добавил тесты на преобразования стандартных тулов (нужно будет сделать тесты на giga_tool)
- Пофиксил метод with_structured_output. Возможно, нужно будет убрать json_mode

* Пофиксил линтер

* Пофиксил линтер
  • Loading branch information
Mikelarg authored Oct 17, 2024
1 parent 1ed1a44 commit 9e790b6
Show file tree
Hide file tree
Showing 10 changed files with 894 additions and 297 deletions.
48 changes: 30 additions & 18 deletions libs/langchain_gigachat/langchain_gigachat/chat_models/gigachat.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@
ToolCallChunk,
ToolMessage,
)
from langchain_core.output_parsers import (
JsonOutputKeyToolsParser,
JsonOutputParser,
PydanticOutputParser,
PydanticToolsParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
Expand All @@ -60,9 +66,7 @@
from pydantic import BaseModel

from langchain_gigachat.chat_models.base_gigachat import _BaseGigaChat
from langchain_gigachat.tools.gigachat_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
from langchain_gigachat.utils.function_calling import (
convert_to_gigachat_function,
convert_to_gigachat_tool,
)
Expand All @@ -76,7 +80,7 @@
r'<img\ssrc="(?P<UUID>.+?)"\sfuse=".+?"/>(?P<postfix>.+)?'
)
VIDEO_SEARCH_REGEX = re.compile(
r'<video\scover="(?P<cover_UUID>.+?)"\ssrc="(?P<UUID>.+?)"\sfuse="true"/>(?P<postfix>.+)?'
r'<video\scover="(?P<cover_UUID>.+?)"\ssrc="(?P<UUID>.+?)"\sfuse="true"/>(?P<postfix>.+)?' # noqa
)


Expand Down Expand Up @@ -588,22 +592,30 @@ def with_structured_output(
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema)
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
key_name = convert_to_gigachat_tool(schema)["function"]["name"]
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore
first_tool_only=True,
)
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
key_name = convert_to_gigachat_tool(schema)["function"]["name"]
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore
first_tool_only=True,
)
else:
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
)
llm = self.bind_tools([schema], tool_choice=key_name)
else:
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
llm = self
output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
if is_pydantic_schema
else JsonOutputParser()
)
llm = self.bind_tools([schema], tool_choice=key_name)

if include_raw:
parser_assign = RunnablePassthrough.assign(
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import copy
from types import GenericAlias
from typing import Any, Dict, List, Type, Union

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from pydantic import BaseModel, model_validator


class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
    """Extract the function call attached to a chat generation's message.

    The call is read from ``message.additional_kwargs["function_call"]`` of the
    first generation in the result list.
    """

    args_only: bool = True
    """Whether to only return the arguments to the function call."""

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Return the function call (or only its arguments) from ``result``.

        Args:
            result: Generations produced by the model; only the first is read.
            partial: Not used by this parser.

        Returns:
            The ``"arguments"`` entry of a deep copy of the function call when
            ``args_only`` is true, otherwise the deep-copied function call.

        Raises:
            OutputParserException: If ``result`` is empty, if the first
                generation is not a ``ChatGeneration``, or if its message has
                no ``function_call`` entry in ``additional_kwargs``.
        """
        # Report an empty result the same way as every other parse failure
        # instead of leaking an IndexError from ``result[0]``.
        if not result:
            raise OutputParserException(
                "Expected at least one generation, but got none."
            )
        generation = result[0]
        if not isinstance(generation, ChatGeneration):
            raise OutputParserException(
                "This output parser can only be used with a chat generation."
            )
        message = generation.message
        try:
            # Deep copy so callers can mutate the returned value without
            # touching the original message.
            func_call = copy.deepcopy(message.additional_kwargs["function_call"])
        except KeyError as exc:
            raise OutputParserException(
                f"Could not parse function call: {exc}"
            ) from exc

        if self.args_only:
            return func_call["arguments"]
        return func_call


class PydanticOutputFunctionsParser(OutputFunctionsParser):
    """Parse a function call's arguments into a pydantic object."""

    pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
    """The pydantic schema to parse the output with.
    If multiple schemas are provided, then the function name will be used to
    determine which schema to use.
    """

    @model_validator(mode="before")
    @classmethod
    def validate_schema(cls, values: dict) -> Any:
        """Derive ``args_only`` from the schema, or check it when supplied.

        Args:
            values: The raw field values, before field validation.

        Returns:
            The same mapping, with ``args_only`` filled in when it was absent.

        Raises:
            ValueError: If ``args_only`` is truthy while ``pydantic_schema`` is
                a dict of schemas.
        """
        # Use ``.get`` so that an omitted ``pydantic_schema`` is reported by
        # pydantic's own required-field validation rather than as a bare
        # KeyError raised from this validator.
        schema = values.get("pydantic_schema")
        if "args_only" not in values:
            # Default: return only the arguments when a single BaseModel
            # subclass was given (a dict of schemas needs the function name).
            values["args_only"] = (
                isinstance(schema, type)
                and not isinstance(schema, GenericAlias)
                and issubclass(schema, BaseModel)
            )
        elif values["args_only"] and isinstance(schema, dict):
            msg = (
                "If multiple pydantic schemas are provided then args_only should be"
                " False."
            )
            raise ValueError(msg)
        return values

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Validate the function call arguments against the pydantic schema.

        Args:
            result: The result of the LLM call.
            partial: Forwarded to the parent parser; not otherwise used here.

        Returns:
            The object produced by validating the arguments with the selected
            schema.

        Raises:
            OutputParserException: If the parent parser cannot extract the
                function call, or if the called function's name has no entry
                in the ``pydantic_schema`` dict.
        """
        parsed = super().parse_result(result, partial=partial)
        schema: Any
        if self.args_only:
            # With ``args_only`` the parent already returned just the arguments.
            schema = self.pydantic_schema
            payload = parsed
        else:
            fn_name = parsed["name"]
            payload = parsed["arguments"]
            if isinstance(self.pydantic_schema, dict):
                try:
                    schema = self.pydantic_schema[fn_name]
                except KeyError as exc:
                    # Keep failures on the parser's own exception type, as the
                    # parent does for a missing function call.
                    raise OutputParserException(
                        f"No pydantic schema provided for function {fn_name!r}"
                    ) from exc
            else:
                schema = self.pydantic_schema

        # Prefer ``model_validate`` when the schema offers it; otherwise fall
        # back to the older ``parse_obj`` method.
        if hasattr(schema, "model_validate"):
            return schema.model_validate(payload)
        return schema.parse_obj(payload)


class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
    """Return one attribute of the pydantic object built by the parent parser."""

    attr_name: str
    """The name of the attribute to return."""

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Parse ``result`` with the parent class and return its ``attr_name``.

        Args:
            result: The result of the LLM call.
            partial: Not used by this method.

        Returns:
            The value of the ``attr_name`` attribute on the parsed object.
        """
        parsed_object = super().parse_result(result)
        return getattr(parsed_object, self.attr_name)
Loading

0 comments on commit 9e790b6

Please sign in to comment.