-
Notifications
You must be signed in to change notification settings - Fork 132
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #221 from danbev/google-serper
feat: add google serper (search) tool
- Loading branch information
Showing
10 changed files
with
1,350 additions
and
529 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 32 additions & 0 deletions
32
crates/llm-chain-openai/examples/self_ask_with_google_search.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use llm_chain::{ | ||
agents::self_ask_with_search::{Agent, EarlyStoppingConfig}, | ||
executor, | ||
tools::tools::GoogleSerper, | ||
}; | ||
|
||
#[tokio::main(flavor = "current_thread")] | ||
async fn main() { | ||
let executor = executor!().unwrap(); | ||
let serper_api_key = std::env::var("SERPER_API_KEY").unwrap(); | ||
let search_tool = GoogleSerper::new(serper_api_key); | ||
let agent = Agent::new( | ||
executor, | ||
search_tool, | ||
EarlyStoppingConfig { | ||
max_iterations: Some(10), | ||
max_time_elapsed_seconds: Some(30.0), | ||
}, | ||
); | ||
let (res, intermediate_steps) = agent | ||
.run("What is the capital of the birthplace of Levy Mwanawasa?") | ||
.await | ||
.unwrap(); | ||
println!( | ||
"Are followup questions needed here: {}", | ||
agent.build_agent_scratchpad(&intermediate_steps) | ||
); | ||
println!( | ||
"Agent final answer: {}", | ||
res.return_values.get("output").unwrap() | ||
); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
use llm_chain::tools::{tools::GoogleSerper, Tool}; | ||
|
||
#[tokio::main(flavor = "current_thread")] | ||
async fn main() { | ||
let serper_api_key = std::env::var("SERPER_API_KEY").unwrap(); | ||
let serper = GoogleSerper::new(serper_api_key); | ||
let result = serper | ||
.invoke_typed(&"Who was the inventor of Catan?".into()) | ||
.await | ||
.unwrap(); | ||
println!("Best answer from Google Serper: {}", result.result); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
use async_trait::async_trait; | ||
use reqwest::Method; | ||
use serde::{Deserialize, Serialize}; | ||
use thiserror::Error; | ||
|
||
use crate::tools::{Describe, Tool, ToolDescription, ToolError}; | ||
|
||
pub struct GoogleSerper { | ||
api_key: String, | ||
} | ||
|
||
impl GoogleSerper { | ||
pub fn new(api_key: String) -> Self { | ||
Self { api_key } | ||
} | ||
} | ||
|
||
#[derive(Serialize, Deserialize)] | ||
pub struct GoogleSerperInput { | ||
pub query: String, | ||
} | ||
|
||
impl From<&str> for GoogleSerperInput { | ||
fn from(value: &str) -> Self { | ||
Self { | ||
query: value.into(), | ||
} | ||
} | ||
} | ||
|
||
impl From<String> for GoogleSerperInput { | ||
fn from(value: String) -> Self { | ||
Self { query: value } | ||
} | ||
} | ||
|
||
impl Describe for GoogleSerperInput { | ||
fn describe() -> crate::tools::Format { | ||
vec![("query", "Search query to find necessary information").into()].into() | ||
} | ||
} | ||
|
||
#[derive(Serialize, Deserialize)] | ||
pub struct GoogleSerperOutput { | ||
pub result: String, | ||
} | ||
|
||
impl From<String> for GoogleSerperOutput { | ||
fn from(value: String) -> Self { | ||
Self { result: value } | ||
} | ||
} | ||
|
||
impl From<GoogleSerperOutput> for String { | ||
fn from(val: GoogleSerperOutput) -> Self { | ||
val.result | ||
} | ||
} | ||
|
||
impl Describe for GoogleSerperOutput { | ||
fn describe() -> crate::tools::Format { | ||
vec![( | ||
"result", | ||
"Information retrieved from the internet that should answer your query", | ||
) | ||
.into()] | ||
.into() | ||
} | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
struct SiteLinks { | ||
title: String, | ||
link: String, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
struct Organic { | ||
title: String, | ||
link: String, | ||
snippet: String, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
struct GoogleSerperResult { | ||
organic: Vec<Organic>, | ||
} | ||
|
||
#[derive(Debug, Error)] | ||
pub enum GoogleSerperError { | ||
#[error("No search results were returned")] | ||
NoResults, | ||
#[error(transparent)] | ||
Yaml(#[from] serde_yaml::Error), | ||
#[error(transparent)] | ||
Request(#[from] reqwest::Error), | ||
} | ||
|
||
impl ToolError for GoogleSerperError {} | ||
|
||
#[async_trait] | ||
impl Tool for GoogleSerper { | ||
type Input = GoogleSerperInput; | ||
|
||
type Output = GoogleSerperOutput; | ||
|
||
type Error = GoogleSerperError; | ||
|
||
async fn invoke_typed(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> { | ||
let client = reqwest::Client::new(); | ||
let response = client | ||
.request(Method::GET, "https://google.serper.dev/search") | ||
.query(&[("q", &input.query)]) | ||
.header("X-API-KEY", self.api_key.clone()) | ||
.send() | ||
.await? | ||
.json::<GoogleSerperResult>() | ||
.await?; | ||
let answer = response | ||
.organic | ||
.first() | ||
.ok_or(GoogleSerperError::NoResults)? | ||
.snippet | ||
.clone(); | ||
Ok(answer.into()) | ||
} | ||
|
||
fn description(&self) -> ToolDescription { | ||
ToolDescription::new( | ||
"Google search", | ||
"Useful for when you need to answer questions about current events. Input should be a search query.", | ||
"Use this to get information about current events.", | ||
GoogleSerperInput::describe(), | ||
GoogleSerperOutput::describe(), | ||
) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters