diff --git a/Cargo.lock b/Cargo.lock
index 41a06be9cf..a5b83899b0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3174,7 +3174,12 @@ version = "0.2.0-dev"
 source = "git+https://github.com/rustformers/llm?rev=2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663#2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663"
 dependencies = [
  "llm-base",
+ "llm-bloom",
+ "llm-gpt2",
+ "llm-gptj",
+ "llm-gptneox",
  "llm-llama",
+ "llm-mpt",
  "serde",
  "tracing",
 ]
@@ -3199,6 +3204,39 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "llm-bloom"
+version = "0.2.0-dev"
+source = "git+https://github.com/rustformers/llm?rev=2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663#2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663"
+dependencies = [
+ "llm-base",
+]
+
+[[package]]
+name = "llm-gpt2"
+version = "0.2.0-dev"
+source = "git+https://github.com/rustformers/llm?rev=2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663#2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663"
+dependencies = [
+ "bytemuck",
+ "llm-base",
+]
+
+[[package]]
+name = "llm-gptj"
+version = "0.2.0-dev"
+source = "git+https://github.com/rustformers/llm?rev=2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663#2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663"
+dependencies = [
+ "llm-base",
+]
+
+[[package]]
+name = "llm-gptneox"
+version = "0.2.0-dev"
+source = "git+https://github.com/rustformers/llm?rev=2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663#2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663"
+dependencies = [
+ "llm-base",
+]
+
 [[package]]
 name = "llm-llama"
 version = "0.2.0-dev"
@@ -3208,6 +3246,14 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "llm-mpt"
+version = "0.2.0-dev"
+source = "git+https://github.com/rustformers/llm?rev=2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663#2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663"
+dependencies = [
+ "llm-base",
+]
+
 [[package]]
 name = "llm-samplers"
 version = "0.0.6"
diff --git a/crates/llm-local/src/lib.rs b/crates/llm-local/src/lib.rs
index 90cce3235e..aa30102e24 100644
--- a/crates/llm-local/src/lib.rs
+++ b/crates/llm-local/src/lib.rs
@@ -6,11 +6,11 @@ use candle::DType;
 use candle_nn::VarBuilder;
 use llm::{
     InferenceFeedback, InferenceParameters, InferenceResponse, InferenceSessionConfig, Model,
-    ModelKVMemoryType, ModelParameters,
+    ModelArchitecture, ModelKVMemoryType, ModelParameters,
 };
 use rand::SeedableRng;
 use spin_core::async_trait;
-use spin_llm::{model_arch, model_name, LlmEngine, MODEL_ALL_MINILM_L6_V2};
+use spin_llm::{LlmEngine, MODEL_ALL_MINILM_L6_V2};
 use spin_world::llm::{self as wasi_llm};
 use std::{
     collections::hash_map::Entry,
@@ -170,14 +170,22 @@ impl LocalLlmEngine {
         &mut self,
         model: wasi_llm::InferencingModel,
     ) -> Result<Arc<dyn Model>, wasi_llm::Error> {
-        let model_name = model_name(&model)?;
         let use_gpu = self.use_gpu;
         let progress_fn = |_| {};
-        let model = match self.inferencing_models.entry((model_name.into(), use_gpu)) {
+        let model = match self.inferencing_models.entry((model.clone(), use_gpu)) {
             Entry::Occupied(o) => o.get().clone(),
             Entry::Vacant(v) => v
                 .insert({
-                    let path = self.registry.join(model_name);
+                    let (path, arch) = if let Some(arch) = well_known_inferencing_model_arch(&model) {
+                        let model_binary = self.registry.join(&model);
+                        if model_binary.exists() {
+                            (model_binary, arch.to_owned())
+                        } else {
+                            walk_registry_for_model(&self.registry, model).await?
+                        }
+                    } else {
+                        walk_registry_for_model(&self.registry, model).await?
+                    };
                     if !self.registry.exists() {
                         return Err(wasi_llm::Error::RuntimeError(
                             format!("The directory expected to house the inferencing model '{}' does not exist.", self.registry.display())
@@ -199,7 +207,7 @@ impl LocalLlmEngine {
                         n_gqa: None,
                     };
                     let model = llm::load_dynamic(
-                        Some(model_arch(&model)?),
+                        Some(arch),
                         &path,
                         llm::TokenizerSource::Embedded,
                         params,
@@ -223,6 +231,80 @@ impl LocalLlmEngine {
     }
 }
 
+/// Get the model binary and arch from walking the registry file structure
+async fn walk_registry_for_model(
+    registry_path: &Path,
+    model: String,
+) -> Result<(PathBuf, ModelArchitecture), wasi_llm::Error> {
+    let mut arch_dirs = tokio::fs::read_dir(registry_path).await.map_err(|e| {
+        wasi_llm::Error::RuntimeError(format!(
+            "Could not read model registry directory '{}': {e}",
+            registry_path.display()
+        ))
+    })?;
+    let mut result = None;
+    'outer: while let Some(arch_dir) = arch_dirs.next_entry().await.map_err(|e| {
+        wasi_llm::Error::RuntimeError(format!(
+            "Failed to read arch directory in model registry: {e}"
+        ))
+    })? {
+        if arch_dir
+            .file_type()
+            .await
+            .map_err(|e| {
+                wasi_llm::Error::RuntimeError(format!(
+                    "Could not read file type of '{}' dir: {e}",
+                    arch_dir.path().display()
+                ))
+            })?
+            .is_file()
+        {
+            continue;
+        }
+        let mut model_files = tokio::fs::read_dir(arch_dir.path()).await.map_err(|e| {
+            wasi_llm::Error::RuntimeError(format!(
+                "Error reading architecture directory in model registry: {e}"
+            ))
+        })?;
+        while let Some(model_file) = model_files.next_entry().await.map_err(|e| {
+            wasi_llm::Error::RuntimeError(format!(
+                "Error reading model file in model registry: {e}"
+            ))
+        })? {
+            if model_file
+                .file_name()
+                .to_str()
+                .map(|m| m == model)
+                .unwrap_or_default()
+            {
+                let arch = arch_dir.file_name();
+                let arch = arch
+                    .to_str()
+                    .ok_or(wasi_llm::Error::ModelNotSupported)?
+                    .parse()
+                    .map_err(|_| wasi_llm::Error::ModelNotSupported)?;
+                result = Some((model_file.path(), arch));
+                break 'outer;
+            }
+        }
+    }
+
+    result.ok_or_else(|| {
+        wasi_llm::Error::InvalidInput(format!(
+            "no model directory found in registry for model '{model}'"
+        ))
+    })
+}
+
+fn well_known_inferencing_model_arch(
+    model: &wasi_llm::InferencingModel,
+) -> Option<ModelArchitecture> {
+    match model.as_str() {
+        "llama2-chat" | "codellama-instruct" => Some(ModelArchitecture::Llama),
+        _ => None,
+    }
+}
+
 async fn generate_embeddings(
     data: Vec<String>,
     model: Arc<(tokenizers::Tokenizer, BertModel)>,
diff --git a/crates/llm/Cargo.toml b/crates/llm/Cargo.toml
index 477912f5f4..08c193e0b6 100644
--- a/crates/llm/Cargo.toml
+++ b/crates/llm/Cargo.toml
@@ -9,7 +9,7 @@ anyhow = "1.0"
 bytesize = "1.1"
 llm = { git = "https://github.com/rustformers/llm", rev = "2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663", features = [
   "tokenizers-remote",
-  "llama",
+  "models",
 ], default-features = false }
 spin-app = { path = "../app" }
 spin-core = { path = "../core" }
diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs
index 804120cf67..c948f3a95e 100644
--- a/crates/llm/src/lib.rs
+++ b/crates/llm/src/lib.rs
@@ -1,6 +1,5 @@
 pub mod host_component;
 
-use llm::ModelArchitecture;
 use spin_app::MetadataKey;
 use spin_core::async_trait;
 use spin_world::llm::{self as wasi_llm};
@@ -72,22 +71,6 @@ impl wasi_llm::Host for LlmDispatch {
     }
 }
 
-pub fn model_name(model: &wasi_llm::InferencingModel) -> Result<&str, wasi_llm::Error> {
-    match model.as_str() {
-        "llama2-chat" | "codellama-instruct" => Ok(model.as_str()),
-        _ => Err(wasi_llm::Error::ModelNotSupported),
-    }
-}
-
-pub fn model_arch(
-    model: &wasi_llm::InferencingModel,
-) -> Result<ModelArchitecture, wasi_llm::Error> {
-    match model.as_str() {
-        "llama2-chat" | "codellama-instruct" => Ok(ModelArchitecture::Llama),
-        _ => Err(wasi_llm::Error::ModelNotSupported),
-    }
-}
-
 fn access_denied_error(model: &str) -> wasi_llm::Error {
     wasi_llm::Error::InvalidInput(format!(
         "The component does not have access to use '{model}'. To give the component access, add '{model}' to the 'ai_models' key for the component in your spin.toml manifest"
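
Reviewer note on the new registry lookup: `walk_registry_for_model` assumes the local model registry is laid out as `<registry>/<architecture>/<model-file>`, where the architecture directory name must `parse()` into an `llm::ModelArchitecture`. The sketch below is a minimal, synchronous illustration of that resolution using only the standard library; it returns the raw directory name instead of a parsed `ModelArchitecture`, and the `.spin/ai-models` registry path in `main` is an illustrative assumption, not something this diff defines.

```rust
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

/// Resolve `model` to (file path, architecture directory name) by walking a
/// registry laid out as `<registry>/<architecture>/<model-file>`.
fn resolve_model(registry: &Path, model: &str) -> io::Result<Option<(PathBuf, String)>> {
    for arch_dir in fs::read_dir(registry)? {
        let arch_dir = arch_dir?;
        // Only architecture directories count; skip stray files at the top level.
        if !arch_dir.file_type()?.is_dir() {
            continue;
        }
        for model_file in fs::read_dir(arch_dir.path())? {
            let model_file = model_file?;
            if model_file.file_name().to_str() == Some(model) {
                let arch = arch_dir.file_name().to_string_lossy().into_owned();
                return Ok(Some((model_file.path(), arch)));
            }
        }
    }
    Ok(None)
}

fn main() -> io::Result<()> {
    // Hypothetical layout: .spin/ai-models/llama/llama2-chat
    let registry = Path::new(".spin/ai-models");
    match resolve_model(registry, "llama2-chat")? {
        Some((path, arch)) => println!("found {} (arch: {arch})", path.display()),
        None => println!("model not found in registry"),
    }
    Ok(())
}
```

In the change itself the directory name is parsed into a `ModelArchitecture`, so with the llm crate's `models` feature enabled (the `Cargo.toml` switch above, which also pulls the new `llm-bloom`, `llm-gpt2`, `llm-gptj`, `llm-gptneox`, and `llm-mpt` packages into `Cargo.lock`), any architecture the crate recognizes can be resolved through the registry walk, while `llama2-chat` and `codellama-instruct` remain the well-known names that resolve directly from the registry root.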