Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add google serper (search) tool #221

Merged
merged 4 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,677 changes: 1,156 additions & 521 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions crates/llm-chain-local/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,13 @@ fn model_params_from_options(opts: OptionsCascade) -> Result<ModelParameters, ()

fn inference_params_from_options(opts: OptionsCascade) -> Result<InferenceParameters, ()> {
let Some(Opt::NThreads(n_threads)) = opts.get(OptDiscriminants::NThreads) else {
return Err(())
return Err(());
};
let Some(Opt::NBatch(n_batch)) = opts.get(OptDiscriminants::NBatch) else {
return Err(())
return Err(());
};
let Some(Opt::TopK(top_k)) = opts.get(OptDiscriminants::TopK) else {
return Err(())
return Err(());
};
let Some(Opt::TopP(top_p)) = opts.get(OptDiscriminants::TopP) else {
return Err(());
Expand All @@ -225,7 +225,9 @@ fn inference_params_from_options(opts: OptionsCascade) -> Result<InferenceParame
return Err(());
};

let Some(Opt::RepeatPenaltyLastN(repetition_penalty_last_n)) = opts.get(OptDiscriminants::RepeatPenaltyLastN) else {
let Some(Opt::RepeatPenaltyLastN(repetition_penalty_last_n)) =
opts.get(OptDiscriminants::RepeatPenaltyLastN)
else {
return Err(());
};

Expand Down
3 changes: 2 additions & 1 deletion crates/llm-chain-milvus/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
name = "llm-chain-milvus"
version = "0.1.0"
edition = "2021"
license = "MIT"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand All @@ -19,4 +20,4 @@ milvus-sdk-rust = "0.1.0"
llm-chain-openai = { path = "../llm-chain-openai" }
tokio = "1.28.2"
serde_yaml = "0.9.21"
rand = "0.8.5"
rand = "0.8.5"
32 changes: 32 additions & 0 deletions crates/llm-chain-openai/examples/self_ask_with_google_search.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use llm_chain::{
agents::self_ask_with_search::{Agent, EarlyStoppingConfig},
executor,
tools::tools::GoogleSerper,
};

#[tokio::main(flavor = "current_thread")]
async fn main() {
let executor = executor!().unwrap();
let serper_api_key = std::env::var("SERPER_API_KEY").unwrap();
let search_tool = GoogleSerper::new(serper_api_key);
let agent = Agent::new(
executor,
search_tool,
EarlyStoppingConfig {
max_iterations: Some(10),
max_time_elapsed_seconds: Some(30.0),
},
);
let (res, intermediate_steps) = agent
.run("What is the capital of the birthplace of Levy Mwanawasa?")
.await
.unwrap();
println!(
"Are followup questions needed here: {}",
agent.build_agent_scratchpad(&intermediate_steps)
);
println!(
"Agent final answer: {}",
res.return_values.get("output").unwrap()
);
}
2 changes: 1 addition & 1 deletion crates/llm-chain-openai/src/chatgpt/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl Executor {

fn get_model_from_invocation_options(&self, opts: &OptionsCascade) -> String {
let Some(Opt::Model(model)) = opts.get(llm_chain::options::OptDiscriminants::Model) else {
return "gpt-3.5-turbo".to_string()
return "gpt-3.5-turbo".to_string();
};
model.to_name()
}
Expand Down
12 changes: 12 additions & 0 deletions crates/llm-chain/examples/google_serper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use llm_chain::tools::{tools::GoogleSerper, Tool};

#[tokio::main(flavor = "current_thread")]
async fn main() {
let serper_api_key = std::env::var("SERPER_API_KEY").unwrap();
let serper = GoogleSerper::new(serper_api_key);
let result = serper
.invoke_typed(&"Who was the inventor of Catan?".into())
.await
.unwrap();
println!("Best answer from Google Serper: {}", result.result);
}
2 changes: 1 addition & 1 deletion crates/llm-chain/src/chains/map_reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ impl Chain {
let mut current_doc = current.extract_last_body().cloned().unwrap_or_default();
while let Some(next) = v.last() {
let Some(next_doc_content) = next.extract_last_body() else {
continue
continue;
};
let mut new_doc = current_doc.clone();
new_doc.push('\n');
Expand Down
2 changes: 1 addition & 1 deletion crates/llm-chain/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ impl<'a> OptionsCascade<'a> {
/// Returns a boolean indicating if options indicate that requests should be streamed or not.
pub fn is_streaming(&self) -> bool {
let Some(Opt::Stream(val)) = self.get(OptDiscriminants::Stream) else {
return false
return false;
};
*val
}
Expand Down
137 changes: 137 additions & 0 deletions crates/llm-chain/src/tools/tools/google_serper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use async_trait::async_trait;
use reqwest::Method;
use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::tools::{Describe, Tool, ToolDescription, ToolError};

pub struct GoogleSerper {
api_key: String,
}

impl GoogleSerper {
pub fn new(api_key: String) -> Self {
Self { api_key }
}
}

#[derive(Serialize, Deserialize)]
pub struct GoogleSerperInput {
pub query: String,
}

impl From<&str> for GoogleSerperInput {
fn from(value: &str) -> Self {
Self {
query: value.into(),
}
}
}

impl From<String> for GoogleSerperInput {
fn from(value: String) -> Self {
Self { query: value }
}
}

impl Describe for GoogleSerperInput {
fn describe() -> crate::tools::Format {
vec![("query", "Search query to find necessary information").into()].into()
}
}

#[derive(Serialize, Deserialize)]
pub struct GoogleSerperOutput {
pub result: String,
}

impl From<String> for GoogleSerperOutput {
fn from(value: String) -> Self {
Self { result: value }
}
}

impl From<GoogleSerperOutput> for String {
fn from(val: GoogleSerperOutput) -> Self {
val.result
}
}

impl Describe for GoogleSerperOutput {
fn describe() -> crate::tools::Format {
vec![(
"result",
"Information retrieved from the internet that should answer your query",
)
.into()]
.into()
}
}

#[derive(Debug, Serialize, Deserialize)]
struct SiteLinks {
title: String,
link: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct Organic {
title: String,
link: String,
snippet: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct GoogleSerperResult {
organic: Vec<Organic>,
}

#[derive(Debug, Error)]
pub enum GoogleSerperError {
#[error("No search results were returned")]
NoResults,
#[error(transparent)]
Yaml(#[from] serde_yaml::Error),
#[error(transparent)]
Request(#[from] reqwest::Error),
}

impl ToolError for GoogleSerperError {}

#[async_trait]
impl Tool for GoogleSerper {
type Input = GoogleSerperInput;

type Output = GoogleSerperOutput;

type Error = GoogleSerperError;

async fn invoke_typed(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
let client = reqwest::Client::new();
let response = client
.request(Method::GET, "https://google.serper.dev/search")
.query(&[("q", &input.query)])
.header("X-API-KEY", self.api_key.clone())
.send()
.await?
.json::<GoogleSerperResult>()
.await?;
let answer = response
.organic
.first()
.ok_or(GoogleSerperError::NoResults)?
.snippet
.clone();
Ok(answer.into())
}

fn description(&self) -> ToolDescription {
ToolDescription::new(
"Google search",
"Useful for when you need to answer questions about current events. Input should be a search query.",
"Use this to get information about current events.",
GoogleSerperInput::describe(),
GoogleSerperOutput::describe(),
)
}
}
2 changes: 2 additions & 0 deletions crates/llm-chain/src/tools/tools/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
mod bash;
mod bing_search;
mod exit;
mod google_serper;
mod python;
mod vectorstore;
pub use bash::{BashTool, BashToolError, BashToolInput, BashToolOutput};
pub use bing_search::{BingSearch, BingSearchError, BingSearchInput, BingSearchOutput};
pub use exit::{ExitTool, ExitToolError, ExitToolInput, ExitToolOutput};
pub use google_serper::{GoogleSerper, GoogleSerperError, GoogleSerperInput, GoogleSerperOutput};
pub use python::{PythonTool, PythonToolError, PythonToolInput, PythonToolOutput};
pub use vectorstore::{
VectorStoreTool, VectorStoreToolError, VectorStoreToolInput, VectorStoreToolOutput,
Expand Down