Skip to content

Commit

Permalink
Merge pull request #221 from danbev/google-serper
Browse files Browse the repository at this point in the history
feat: add google serper (search) tool
  • Loading branch information
williamhogman authored Oct 19, 2023
2 parents fff9b19 + 1f07014 commit c423f22
Show file tree
Hide file tree
Showing 10 changed files with 1,350 additions and 529 deletions.
1,677 changes: 1,156 additions & 521 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions crates/llm-chain-local/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,13 @@ fn model_params_from_options(opts: OptionsCascade) -> Result<ModelParameters, ()

fn inference_params_from_options(opts: OptionsCascade) -> Result<InferenceParameters, ()> {
let Some(Opt::NThreads(n_threads)) = opts.get(OptDiscriminants::NThreads) else {
return Err(())
return Err(());
};
let Some(Opt::NBatch(n_batch)) = opts.get(OptDiscriminants::NBatch) else {
return Err(())
return Err(());
};
let Some(Opt::TopK(top_k)) = opts.get(OptDiscriminants::TopK) else {
return Err(())
return Err(());
};
let Some(Opt::TopP(top_p)) = opts.get(OptDiscriminants::TopP) else {
return Err(());
Expand All @@ -225,7 +225,9 @@ fn inference_params_from_options(opts: OptionsCascade) -> Result<InferenceParame
return Err(());
};

let Some(Opt::RepeatPenaltyLastN(repetition_penalty_last_n)) = opts.get(OptDiscriminants::RepeatPenaltyLastN) else {
let Some(Opt::RepeatPenaltyLastN(repetition_penalty_last_n)) =
opts.get(OptDiscriminants::RepeatPenaltyLastN)
else {
return Err(());
};

Expand Down
3 changes: 2 additions & 1 deletion crates/llm-chain-milvus/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
name = "llm-chain-milvus"
version = "0.1.0"
edition = "2021"
license = "MIT"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand All @@ -19,4 +20,4 @@ milvus-sdk-rust = "0.1.0"
llm-chain-openai = { path = "../llm-chain-openai" }
tokio = "1.28.2"
serde_yaml = "0.9.21"
rand = "0.8.5"
rand = "0.8.5"
32 changes: 32 additions & 0 deletions crates/llm-chain-openai/examples/self_ask_with_google_search.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use llm_chain::{
agents::self_ask_with_search::{Agent, EarlyStoppingConfig},
executor,
tools::tools::GoogleSerper,
};

#[tokio::main(flavor = "current_thread")]
async fn main() {
let executor = executor!().unwrap();
let serper_api_key = std::env::var("SERPER_API_KEY").unwrap();
let search_tool = GoogleSerper::new(serper_api_key);
let agent = Agent::new(
executor,
search_tool,
EarlyStoppingConfig {
max_iterations: Some(10),
max_time_elapsed_seconds: Some(30.0),
},
);
let (res, intermediate_steps) = agent
.run("What is the capital of the birthplace of Levy Mwanawasa?")
.await
.unwrap();
println!(
"Are followup questions needed here: {}",
agent.build_agent_scratchpad(&intermediate_steps)
);
println!(
"Agent final answer: {}",
res.return_values.get("output").unwrap()
);
}
2 changes: 1 addition & 1 deletion crates/llm-chain-openai/src/chatgpt/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl Executor {

fn get_model_from_invocation_options(&self, opts: &OptionsCascade) -> String {
let Some(Opt::Model(model)) = opts.get(llm_chain::options::OptDiscriminants::Model) else {
return "gpt-3.5-turbo".to_string()
return "gpt-3.5-turbo".to_string();
};
model.to_name()
}
Expand Down
12 changes: 12 additions & 0 deletions crates/llm-chain/examples/google_serper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use llm_chain::tools::{tools::GoogleSerper, Tool};

#[tokio::main(flavor = "current_thread")]
async fn main() {
let serper_api_key = std::env::var("SERPER_API_KEY").unwrap();
let serper = GoogleSerper::new(serper_api_key);
let result = serper
.invoke_typed(&"Who was the inventor of Catan?".into())
.await
.unwrap();
println!("Best answer from Google Serper: {}", result.result);
}
2 changes: 1 addition & 1 deletion crates/llm-chain/src/chains/map_reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ impl Chain {
let mut current_doc = current.extract_last_body().cloned().unwrap_or_default();
while let Some(next) = v.last() {
let Some(next_doc_content) = next.extract_last_body() else {
continue
continue;
};
let mut new_doc = current_doc.clone();
new_doc.push('\n');
Expand Down
2 changes: 1 addition & 1 deletion crates/llm-chain/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ impl<'a> OptionsCascade<'a> {
/// Returns a boolean indicating if options indicate that requests should be streamed or not.
pub fn is_streaming(&self) -> bool {
let Some(Opt::Stream(val)) = self.get(OptDiscriminants::Stream) else {
return false
return false;
};
*val
}
Expand Down
137 changes: 137 additions & 0 deletions crates/llm-chain/src/tools/tools/google_serper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use async_trait::async_trait;
use reqwest::Method;
use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::tools::{Describe, Tool, ToolDescription, ToolError};

pub struct GoogleSerper {
api_key: String,
}

impl GoogleSerper {
pub fn new(api_key: String) -> Self {
Self { api_key }
}
}

#[derive(Serialize, Deserialize)]
pub struct GoogleSerperInput {
pub query: String,
}

impl From<&str> for GoogleSerperInput {
fn from(value: &str) -> Self {
Self {
query: value.into(),
}
}
}

impl From<String> for GoogleSerperInput {
fn from(value: String) -> Self {
Self { query: value }
}
}

impl Describe for GoogleSerperInput {
fn describe() -> crate::tools::Format {
vec![("query", "Search query to find necessary information").into()].into()
}
}

#[derive(Serialize, Deserialize)]
pub struct GoogleSerperOutput {
pub result: String,
}

impl From<String> for GoogleSerperOutput {
fn from(value: String) -> Self {
Self { result: value }
}
}

impl From<GoogleSerperOutput> for String {
fn from(val: GoogleSerperOutput) -> Self {
val.result
}
}

impl Describe for GoogleSerperOutput {
fn describe() -> crate::tools::Format {
vec![(
"result",
"Information retrieved from the internet that should answer your query",
)
.into()]
.into()
}
}

#[derive(Debug, Serialize, Deserialize)]
struct SiteLinks {
title: String,
link: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct Organic {
title: String,
link: String,
snippet: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct GoogleSerperResult {
organic: Vec<Organic>,
}

#[derive(Debug, Error)]
pub enum GoogleSerperError {
#[error("No search results were returned")]
NoResults,
#[error(transparent)]
Yaml(#[from] serde_yaml::Error),
#[error(transparent)]
Request(#[from] reqwest::Error),
}

impl ToolError for GoogleSerperError {}

#[async_trait]
impl Tool for GoogleSerper {
type Input = GoogleSerperInput;

type Output = GoogleSerperOutput;

type Error = GoogleSerperError;

async fn invoke_typed(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
let client = reqwest::Client::new();
let response = client
.request(Method::GET, "https://google.serper.dev/search")
.query(&[("q", &input.query)])
.header("X-API-KEY", self.api_key.clone())
.send()
.await?
.json::<GoogleSerperResult>()
.await?;
let answer = response
.organic
.first()
.ok_or(GoogleSerperError::NoResults)?
.snippet
.clone();
Ok(answer.into())
}

fn description(&self) -> ToolDescription {
ToolDescription::new(
"Google search",
"Useful for when you need to answer questions about current events. Input should be a search query.",
"Use this to get information about current events.",
GoogleSerperInput::describe(),
GoogleSerperOutput::describe(),
)
}
}
2 changes: 2 additions & 0 deletions crates/llm-chain/src/tools/tools/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
mod bash;
mod bing_search;
mod exit;
mod google_serper;
mod python;
mod vectorstore;
pub use bash::{BashTool, BashToolError, BashToolInput, BashToolOutput};
pub use bing_search::{BingSearch, BingSearchError, BingSearchInput, BingSearchOutput};
pub use exit::{ExitTool, ExitToolError, ExitToolInput, ExitToolOutput};
pub use google_serper::{GoogleSerper, GoogleSerperError, GoogleSerperInput, GoogleSerperOutput};
pub use python::{PythonTool, PythonToolError, PythonToolInput, PythonToolOutput};
pub use vectorstore::{
VectorStoreTool, VectorStoreToolError, VectorStoreToolInput, VectorStoreToolOutput,
Expand Down

0 comments on commit c423f22

Please sign in to comment.