Simplify langchain config (#282)
* Made it possible to have config with langchain

* Version bump
whitead committed Jun 6, 2024
1 parent 334b01e commit b7a3d68
Showing 2 changed files with 8 additions and 13 deletions.
paperqa/llms.py (7 additions & 12 deletions)
@@ -459,7 +459,7 @@ async def achat(self, client: Any, messages: list[dict[str, str]]) -> str:
         sys_message = next(
             (m["content"] for m in messages if m["role"] == "system"), None
         )
-        # BECAUISE THEY DO NOT USE NONE TO INDICATE SENTINEL
+        # BECAUSE THEY DO NOT USE NONE TO INDICATE SENTINEL
         # LIKE ANY SANE PERSON
         if sys_message:
             completion = await aclient.messages.create(
@@ -675,15 +675,10 @@ async def similarity_search(
     )
 
 
-# All the langchain stuff is below
-# Many confusing woes here because langchain
-# is not serializable and so we have to
-# do some gymnastics to make it work
-
-
 class LangchainLLMModel(LLMModel):
     """A wrapper around the wrapper langchain."""
 
+    config: dict = Field(default={"temperature": 0.1})
     name: str = "langchain"
 
     def infer_llm_type(self, client: Any) -> str:
@@ -695,10 +690,10 @@ def infer_llm_type(self, client: Any) -> str:
         return "completion"
 
     async def acomplete(self, client: Any, prompt: str) -> str:
-        return await client.ainvoke(prompt)
+        return await client.ainvoke(prompt, **self.config)
 
     async def acomplete_iter(self, client: Any, prompt: str) -> Any:
-        async for chunk in cast(AsyncGenerator, client.astream(prompt)):
+        async for chunk in cast(AsyncGenerator, client.astream(prompt, **self.config)):
             yield chunk
 
     async def achat(self, client: Any, messages: list[dict[str, str]]) -> str:
@@ -712,7 +707,7 @@ async def achat(self, client: Any, messages: list[dict[str, str]]) -> str:
                 lc_messages.append(SystemMessage(content=m["content"]))
             else:
                 raise ValueError(f"Unknown role: {m['role']}")
-        return (await client.ainvoke(lc_messages)).content
+        return (await client.ainvoke(lc_messages, **self.config)).content
 
     async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any:
         from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
@@ -725,7 +720,7 @@ async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any:
                 lc_messages.append(SystemMessage(content=m["content"]))
             else:
                 raise ValueError(f"Unknown role: {m['role']}")
-        async for chunk in client.astream(lc_messages):
+        async for chunk in client.astream(lc_messages, **self.config):
             yield chunk.content
 
 
@@ -871,7 +866,7 @@ def llm_model_factory(llm: str) -> LLMModel:
     if llm != "default":
         if is_openai_model(llm):
             return OpenAILLMModel(config={"model": llm})
-        elif llm == "langchain":  # noqa: RET505
+        elif llm.startswith("langchain"):  # noqa: RET505
             return LangchainLLMModel()
         elif "claude" in llm:
             return AnthropicLLMModel(config={"model": llm})
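With this change, LangchainLLMModel carries a config dict (defaulting to {"temperature": 0.1}) that is unpacked into every ainvoke/astream call on the underlying langchain client, and llm_model_factory routes any name starting with "langchain" to this wrapper instead of requiring an exact match. A minimal usage sketch, assuming langchain-openai is installed; the ChatOpenAI client and model name are illustrative choices, and each config key must be a valid invoke-time keyword argument for the client:

import asyncio

from langchain_openai import ChatOpenAI  # assumed client; any langchain chat model
from paperqa.llms import LangchainLLMModel


async def main() -> None:
    # config entries are forwarded as keyword arguments to the client's
    # ainvoke/astream calls, per the diff above
    model = LangchainLLMModel(config={"temperature": 0.1})
    client = ChatOpenAI(model="gpt-4o-mini")  # hypothetical model choice
    answer = await model.achat(
        client, [{"role": "user", "content": "Name one greenhouse gas."}]
    )
    print(answer)


asyncio.run(main())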
pyproject.toml (1 addition & 1 deletion)
@@ -40,7 +40,7 @@ name = "paper-qa"
 readme = "README.md"
 requires-python = ">=3.8"
 urls = {repository = "https://github.com/whitead/paper-qa"}
-version = "4.6.1"
+version = "4.7.0"
 
 [tool.codespell]
 check-filenames = true