Skip to content

Commit

Permalink
feat: implement LL models
Browse files Browse the repository at this point in the history
  • Loading branch information
zensh committed Jan 12, 2025
1 parent 9a01a8b commit e77a097
Show file tree
Hide file tree
Showing 12 changed files with 1,476 additions and 119 deletions.
81 changes: 71 additions & 10 deletions anda_core/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,15 @@ pub trait AgentContext:

/// BaseContext is the core context interface available when calling Agent or Tool.
/// It provides access to various feature sets including:
/// - StateFeatures: User, caller, time, and cancellation token
/// - KeysFeatures: Cryptographic key operations
/// - StoreFeatures: Persistent storage
/// - CacheFeatures: In-memory caching
/// - CanisterFeatures: ICP blockchain interactions
/// - HttpFeatures: HTTP request capabilities
pub trait BaseContext:
Sized
+ StateFeatures<Self::Error>
+ KeysFeatures<Self::Error>
+ StoreFeatures<Self::Error>
+ CacheFeatures<Self::Error>
Expand All @@ -71,7 +73,10 @@ pub trait BaseContext:
{
/// Error type for all context operations
type Error: Into<BoxError>;
}

/// StateFeatures is one of the context feature sets available when calling Agent or Tool.
pub trait StateFeatures<Err>: Sized {
/// Gets the username from request context.
/// Note: This is not verified and should not be used as a trusted identifier.
/// For example, if triggered by a bot of X platform, this might be the username
Expand Down Expand Up @@ -111,22 +116,42 @@ pub struct AgentOutput {
/// Should be None when finish_reason is "stop" or "tool_calls"
pub failed_reason: Option<String>,

/// The function name to call when using Function Calling
pub function: Option<String>,
/// Tool call that this message is responding to. If this message is a response to a tool call, this field should be set to the tool call ID.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,

/// Extracted valid JSON when using Function Calling or JSON output mode
/// If no valid JSON is extracted, the raw content remains in `content`
/// Extracted valid JSON when using JSON response_format
#[serde(skip_serializing_if = "Option::is_none")]
pub extracted_json: Option<Value>,
}

/// Represents a tool call response with it's ID, function name, and arguments
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct ToolCall {
pub id: String,
pub name: String,
pub args: String,

/// The result of the tool call, processed by agents engine, if available
pub result: Option<Value>,
}

/// Represents a message in the agent's conversation history
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct Message {
/// Message role: "system", "user", or "assistant"
pub struct MessageInput {
/// Message role: "developer", "system", "user", "assistant", "tool"
pub role: String,

/// The content of the message
pub content: String,

/// An optional name for the participant. Provides the model information to differentiate between participants of the same role.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,

/// Tool call that this message is responding to. If this message is a response to a tool call, this field should be set to the tool call ID.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}

/// Defines a callable function with its metadata and schema
Expand All @@ -140,6 +165,42 @@ pub struct FunctionDefinition {

/// JSON schema defining the function's parameters
pub parameters: serde_json::Value,

/// Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the parameters field. Only a subset of JSON Schema is supported when strict is true.
#[serde(skip_serializing_if = "Option::is_none")]
pub strict: Option<bool>,
}

/// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct CompletionRequest {
/// The prompt to be sent to the completion model provider as "developer" or "system" role
pub prompt: String,

/// The preamble to be sent to the completion model provider
pub preamble: Option<String>,

/// The chat history to be sent to the completion model provider
pub chat_history: Vec<MessageInput>,

/// The tools to be sent to the completion model provider
pub tools: Vec<FunctionDefinition>,

/// The temperature to be sent to the completion model provider
pub temperature: Option<f64>,

/// The max tokens to be sent to the completion model provider
pub max_tokens: Option<u64>,

/// An object specifying the JSON format that the model must output.
/// https://platform.openai.com/docs/guides/structured-outputs
/// The format can be one of the following:
/// `{ "type": "json_object" }`
/// `{ "type": "json_schema", "json_schema": {...} }`
pub response_format: Option<Value>,

/// The stop sequence to be sent to the completion model provider
pub stop: Option<Vec<String>>,
}

/// Provides LLM completion capabilities for agents
Expand All @@ -153,10 +214,7 @@ pub trait CompletionFeatures<Err>: Sized {
/// * `tools` - Available functions the model can call
fn completion(
&self,
prompt: &str,
json_output: bool,
chat_history: &[Message],
tools: &[FunctionDefinition],
req: CompletionRequest,
) -> impl Future<Output = Result<AgentOutput, Err>> + Send;
}

Expand All @@ -172,6 +230,9 @@ pub struct Embedding {

/// Provides text embedding capabilities for agents
pub trait EmbeddingFeatures<Err>: Sized {
/// The number of dimensions in the embedding vector.
fn ndims(&self) -> usize;

/// Generates embeddings for multiple texts in a batch
/// Returns a vector of Embedding structs in the same order as input texts
fn embed(
Expand Down
6 changes: 3 additions & 3 deletions anda_core/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ pub async fn cbor_rpc(
body: Vec<u8>,
) -> Result<ByteBuf, HttpRPCError> {
let mut headers = headers.unwrap_or_default();
let cb: http::HeaderValue = CONTENT_TYPE_CBOR.parse().unwrap();
headers.insert(header::CONTENT_TYPE, cb.clone());
headers.insert(header::ACCEPT, cb);
let ct: http::HeaderValue = CONTENT_TYPE_CBOR.parse().unwrap();
headers.insert(header::CONTENT_TYPE, ct.clone());
headers.insert(header::ACCEPT, ct);
let res = client
.post(endpoint)
.headers(headers)
Expand Down
14 changes: 13 additions & 1 deletion anda_core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::ops::{Deref, DerefMut};
use std::{
future::Future,
ops::{Deref, DerefMut},
pin::Pin,
};

pub mod agent;
pub mod context;
Expand All @@ -14,6 +18,9 @@ pub use tool::*;
/// This is commonly used as a return type for functions that can return various error types.
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// A type alias for a boxed future that is thread-safe and sendable across threads.
pub type BoxPinFut<T> = Pin<Box<dyn Future<Output = T> + Send>>;

/// A global state manager for Agent or Tool
///
/// Wraps any type `S` to provide shared state management with
Expand All @@ -36,3 +43,8 @@ impl<S> DerefMut for State<S> {
&mut self.0
}
}

/// Joins two paths together
pub fn join_path(a: &Path, b: &Path) -> Path {
Path::from(format!("{}/{}", a, b))
}
71 changes: 55 additions & 16 deletions anda_engine/src/context/agent.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,40 @@
use anda_core::{
AgentContext, AgentOutput, AgentSet, BaseContext, BoxError, CacheExpiry, CacheFeatures,
CancellationToken, CanisterFeatures, CompletionFeatures, Embedding, EmbeddingFeatures,
FunctionDefinition, HttpFeatures, KeysFeatures, Message, ObjectMeta, Path, PutMode, PutResult,
StoreFeatures, ToolSet, Value, VectorSearchFeatures,
CancellationToken, CanisterFeatures, CompletionFeatures, CompletionRequest, Embedding,
EmbeddingFeatures, HttpFeatures, KeysFeatures, ObjectMeta, Path, PutMode, PutResult,
StateFeatures, StoreFeatures, ToolSet, Value, VectorSearchFeatures,
};
use candid::{utils::ArgumentEncoder, CandidType, Principal};
use ciborium::from_reader;
use serde::{de::DeserializeOwned, Serialize};
use std::{future::Future, sync::Arc, time::Duration};

use super::base::BaseCtx;
use crate::{
database::{VectorSearchFeaturesDyn, VectorStore},
model::Model,
};

pub struct AgentCtx {
pub(crate) base: BaseCtx,
pub(crate) model: Model,
pub(crate) store: VectorStore,
pub(crate) tools: Arc<ToolSet<BaseCtx>>,
pub(crate) agents: Arc<AgentSet<AgentCtx>>,
}

impl AgentCtx {
pub fn new(
base: BaseCtx,
model: Model,
store: VectorStore,
tools: Arc<ToolSet<BaseCtx>>,
agents: Arc<AgentSet<AgentCtx>>,
) -> Self {
Self {
base,
model,
store,
tools,
agents,
}
Expand All @@ -32,6 +43,8 @@ impl AgentCtx {
pub fn child(&self, agent_name: &str) -> Result<Self, BoxError> {
Ok(Self {
base: self.base.child(format!("A:{}", agent_name))?,
model: self.model.clone(),
store: self.store.clone(),
tools: self.tools.clone(),
agents: self.agents.clone(),
})
Expand All @@ -51,6 +64,8 @@ impl AgentCtx {
base: self
.base
.child_with(format!("A:{}", agent_name), user, caller)?,
model: self.model.clone(),
store: self.store.clone(),
tools: self.tools.clone(),
agents: self.agents.clone(),
})
Expand Down Expand Up @@ -114,46 +129,70 @@ impl AgentContext for AgentCtx {
}

impl CompletionFeatures<BoxError> for AgentCtx {
async fn completion(
&self,
prompt: &str,
json_output: bool,
chat_history: &[Message],
tools: &[FunctionDefinition],
) -> Result<AgentOutput, BoxError> {
Err("Not implemented".into())
async fn completion(&self, req: CompletionRequest) -> Result<AgentOutput, BoxError> {
let mut res = self.model.completion(req).await?;
// auto call tools
if let Some(tools) = &mut res.tool_calls {
for tool in tools {
if let Ok(args) = serde_json::from_str(&tool.args) {
if let Ok(val) = self.tool_call(&tool.id, &args).await {
tool.result = Some(val);
}
}
}
}

Ok(res)
}
}

impl EmbeddingFeatures<BoxError> for AgentCtx {
fn ndims(&self) -> usize {
self.model.ndims()
}

async fn embed(
&self,
texts: impl IntoIterator<Item = String> + Send,
) -> Result<Vec<Embedding>, BoxError> {
Err("Not implemented".into())
self.model.embed(texts).await
}

async fn embed_query(&self, text: &str) -> Result<Embedding, BoxError> {
Err("Not implemented".into())
self.model.embed_query(text).await
}
}

impl VectorSearchFeatures<BoxError> for AgentCtx {
/// Get the top n documents based on the distance to the given query.
/// The result is a list of tuples of the form (score, id, document)
async fn top_n<T>(&self, query: &str, n: usize) -> Result<Vec<(String, T)>, BoxError> {
Err("Not implemented".into())
async fn top_n<T>(&self, query: &str, n: usize) -> Result<Vec<(String, T)>, BoxError>
where
T: DeserializeOwned,
{
let res = self
.store
.top_n(self.base.path.clone(), query.to_string(), n)
.await?;
Ok(res
.into_iter()
.filter_map(|(id, doc)| from_reader(doc.as_ref()).ok().map(|doc| (id, doc)))
.collect())
}

/// Same as `top_n` but returns the document ids only.
async fn top_n_ids(&self, query: &str, n: usize) -> Result<Vec<String>, BoxError> {
Err("Not implemented".into())
self.store
.top_n_ids(self.base.path.clone(), query.to_string(), n)
.await
}
}

impl BaseContext for AgentCtx {
type Error = BoxError;
}

impl StateFeatures<BoxError> for AgentCtx {
fn user(&self) -> String {
self.base.user()
}
Expand Down
Loading

0 comments on commit e77a097

Please sign in to comment.