diff --git a/src/plugin.rs b/src/plugin.rs index d6e381809..ac58842b6 100644 --- a/src/plugin.rs +++ b/src/plugin.rs @@ -15,22 +15,35 @@ pub mod ffi { }; } +/// Preload config for initializing the wasi_nn plug-in. #[cfg(feature = "wasi_nn")] #[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))] #[derive(Debug)] pub struct NNPreload { /// The alias of the model in the WASI-NN environment. - pub alias: String, + alias: String, /// The inference backend. - pub backend: GraphEncoding, + backend: GraphEncoding, /// The execution target, on which the inference runs. - pub target: ExecutionTarget, + target: ExecutionTarget, /// The path to the model file. Note that the path is the guest path instead of the host path. - pub path: std::path::PathBuf, + path: std::path::PathBuf, } #[cfg(feature = "wasi_nn")] #[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))] impl NNPreload { + /// Creates a new preload config. + /// + /// # Arguments + /// + /// * `alias` - The alias of the model in the WASI-NN environment. + /// + /// * `backend` - The inference backend. + /// + /// * `target` - The execution target, on which the inference runs. + /// + /// * `path` - The path to the model file. Note that the path is the guest path instead of the host path. + /// pub fn new( alias: impl AsRef, backend: GraphEncoding, @@ -59,6 +72,90 @@ impl std::fmt::Display for NNPreload { ) } } +#[cfg(feature = "wasi_nn")] +#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))] +impl std::str::FromStr for NNPreload { + type Err = WasmEdgeError; + + fn from_str(preload: &str) -> std::result::Result { + let nn_preload: Vec<&str> = preload.split(':').collect(); + if nn_preload.len() != 4 { + return Err(WasmEdgeError::Operation(format!( + "Failed to convert to NNPreload value. Invalid preload string: {}. The correct format is: 'alias:backend:target:path'", + preload + ))); + } + let (alias, backend, target, path) = ( + nn_preload[0].to_string(), + nn_preload[1] + .parse::() + .map_err(|err| WasmEdgeError::Operation(err.to_string()))?, + nn_preload[2] + .parse::() + .map_err(|err| WasmEdgeError::Operation(err.to_string()))?, + std::path::PathBuf::from(nn_preload[3]), + ); + + Ok(Self::new(alias, backend, target, path)) + } +} + +#[cfg(feature = "wasi_nn")] +#[test] +fn test_generate_nnpreload_from_str() { + use std::str::FromStr; + + // valid preload string + let preload = "default:GGML:CPU:llama-2-7b-chat.Q5_K_M.gguf"; + let result = NNPreload::from_str(preload); + assert!(result.is_ok()); + let nnpreload = result.unwrap(); + assert_eq!(nnpreload.alias, "default"); + assert_eq!(nnpreload.backend, GraphEncoding::GGML); + assert_eq!(nnpreload.target, ExecutionTarget::CPU); + assert_eq!( + nnpreload.path, + std::path::PathBuf::from("llama-2-7b-chat.Q5_K_M.gguf") + ); + + // invalid preload string + let preload = "default:CPU:GGML:llama-2-7b-chat.Q5_K_M.gguf"; + let result = NNPreload::from_str(preload); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert_eq!( + WasmEdgeError::Operation( + "Failed to convert to NNBackend value. Unknown NNBackend type: CPU".to_string() + ), + err + ); + + // invalid preload string: unsupported target + let preload = "default:GGML:NPU:llama-2-7b-chat.Q5_K_M.gguf"; + let result = NNPreload::from_str(preload); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert_eq!( + WasmEdgeError::Operation( + "Failed to convert to ExecutionTarget value. Unknown ExecutionTarget type: NPU" + .to_string() + ), + err + ); + + // invalid preload string: invalid format + let preload = "default:GGML:CPU"; + let result = NNPreload::from_str(preload); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert_eq!( + WasmEdgeError::Operation( + "Failed to convert to NNPreload value. Invalid preload string: default:GGML:CPU. The correct format is: 'alias:backend:target:path'" + .to_string() + ), + err + ); +} /// Describes the encoding of the graph. #[cfg(feature = "wasi_nn")] @@ -162,7 +259,7 @@ impl PluginManager { /// * If the path is not given, then the default plugin paths will be used. The default plugin paths are /// /// * The environment variable "WASMEDGE_PLUGIN_PATH". - /// + /// /// * The `../plugin/` directory related to the WasmEdge installation path. /// /// * The `wasmedge/` directory under the library path if the WasmEdge is installed under the "/usr". @@ -186,6 +283,39 @@ impl PluginManager { } } + /// Initialize the wasi_nn plug-in with the preloads. + /// + /// Note that this function is only available after loading the wasi_nn plug-in and before creating, and before creating the module instance from the plug-in. + /// + /// # Argument + /// + /// * `preloads` - The preload list. + /// + /// # Example + /// + /// ```ignore + /// // load wasinn-pytorch-plugin from the default plugin directory: /usr/local/lib/wasmedge + /// PluginManager::load(None)?; + /// // preload named model + /// PluginManager::nn_preload(vec![NNPreload::new( + /// "default", + /// GraphEncoding::GGML, + /// ExecutionTarget::CPU, + /// "llama-2-7b-chat.Q5_K_M.gguf", + /// )]); + /// ``` + /// + /// If a preload is string, then use `NNPreload::from_str` to create a `NNPreload` instance: + /// + /// ```ignore + /// use std::str::FromStr; + /// + /// // load wasinn-pytorch-plugin from the default plugin directory: /usr/local/lib/wasmedge + /// PluginManager::load(None)?; + /// // preload named model + /// PluginManager::nn_preload(vec![NNPreload::from_str("default:GGML:CPU:llama-2-7b-chat.Q5_K_M.gguf")?]); + /// + /// ``` #[cfg(feature = "wasi_nn")] #[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))] pub fn nn_preload(preloads: Vec) {