[feat] Implement FromStr trait for NNPreload in wasmedge-sdk (#81)
* feat(rust-sdk): implement `FromStr` trait for `NNPreload`

Signed-off-by: Xin Liu <[email protected]>

* chore(rust-sdk): add the checking code for `NNPreload::from_str`

Signed-off-by: Xin Liu <[email protected]>

---------

Signed-off-by: Xin Liu <[email protected]>
apepkuss authored Oct 26, 2023
1 parent 37cc270 commit 3146e89
Showing 1 changed file with 135 additions and 5 deletions.
140 changes: 135 additions & 5 deletions src/plugin.rs
@@ -15,22 +15,35 @@ pub mod ffi {
    };
}

/// Preload config for initializing the wasi_nn plug-in.
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
#[derive(Debug)]
pub struct NNPreload {
    /// The alias of the model in the WASI-NN environment.
-    pub alias: String,
+    alias: String,
    /// The inference backend.
-    pub backend: GraphEncoding,
+    backend: GraphEncoding,
    /// The execution target, on which the inference runs.
-    pub target: ExecutionTarget,
+    target: ExecutionTarget,
    /// The path to the model file. Note that the path is the guest path instead of the host path.
-    pub path: std::path::PathBuf,
+    path: std::path::PathBuf,
}
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl NNPreload {
    /// Creates a new preload config.
    ///
    /// # Arguments
    ///
    /// * `alias` - The alias of the model in the WASI-NN environment.
    ///
    /// * `backend` - The inference backend.
    ///
    /// * `target` - The execution target, on which the inference runs.
    ///
    /// * `path` - The path to the model file. Note that the path is the guest path instead of the host path.
    ///
    pub fn new(
        alias: impl AsRef<str>,
        backend: GraphEncoding,
@@ -59,6 +72,90 @@ impl std::fmt::Display for NNPreload {
        )
    }
}
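Taken together, `new` and the `Display` impl above give a programmatic way to build a preload config. A minimal sketch, not part of the commit — the import path and model file name are assumptions, and the comment about the `Display` output assumes it renders the same colon-separated form that `from_str` (below) parses:

```rust
use wasmedge_sdk::plugin::{ExecutionTarget, GraphEncoding, NNPreload}; // assumed import path

fn main() {
    // Build a preload config for a GGML model running on the CPU.
    let preload = NNPreload::new(
        "default",
        GraphEncoding::GGML,
        ExecutionTarget::CPU,
        "llama-2-7b-chat.Q5_K_M.gguf",
    );

    // The `Display` body is truncated in this diff; presumably it prints
    // "default:GGML:CPU:llama-2-7b-chat.Q5_K_M.gguf".
    println!("{preload}");
}
```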
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl std::str::FromStr for NNPreload {
    type Err = WasmEdgeError;

    fn from_str(preload: &str) -> std::result::Result<Self, Self::Err> {
        let nn_preload: Vec<&str> = preload.split(':').collect();
        if nn_preload.len() != 4 {
            return Err(WasmEdgeError::Operation(format!(
                "Failed to convert to NNPreload value. Invalid preload string: {}. The correct format is: 'alias:backend:target:path'",
                preload
            )));
        }
        let (alias, backend, target, path) = (
            nn_preload[0].to_string(),
            nn_preload[1]
                .parse::<GraphEncoding>()
                .map_err(|err| WasmEdgeError::Operation(err.to_string()))?,
            nn_preload[2]
                .parse::<ExecutionTarget>()
                .map_err(|err| WasmEdgeError::Operation(err.to_string()))?,
            std::path::PathBuf::from(nn_preload[3]),
        );

        Ok(Self::new(alias, backend, target, path))
    }
}
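Since `FromStr` is now implemented, `str::parse` becomes available as well. A small sketch of both spellings — the import paths are assumptions and the model file name is illustrative:

```rust
use std::str::FromStr;

use wasmedge_sdk::error::WasmEdgeError; // assumed import paths
use wasmedge_sdk::plugin::NNPreload;

fn parse_preload() -> Result<NNPreload, WasmEdgeError> {
    // `str::parse` and `NNPreload::from_str` are equivalent once
    // `FromStr` is implemented.
    let _via_parse: NNPreload = "default:GGML:CPU:llama-2-7b-chat.Q5_K_M.gguf".parse()?;
    NNPreload::from_str("default:GGML:CPU:llama-2-7b-chat.Q5_K_M.gguf")
}
```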

#[cfg(feature = "wasi_nn")]
#[test]
fn test_generate_nnpreload_from_str() {
    use std::str::FromStr;

    // valid preload string
    let preload = "default:GGML:CPU:llama-2-7b-chat.Q5_K_M.gguf";
    let result = NNPreload::from_str(preload);
    assert!(result.is_ok());
    let nnpreload = result.unwrap();
    assert_eq!(nnpreload.alias, "default");
    assert_eq!(nnpreload.backend, GraphEncoding::GGML);
    assert_eq!(nnpreload.target, ExecutionTarget::CPU);
    assert_eq!(
        nnpreload.path,
        std::path::PathBuf::from("llama-2-7b-chat.Q5_K_M.gguf")
    );

    // invalid preload string: unknown backend
    let preload = "default:CPU:GGML:llama-2-7b-chat.Q5_K_M.gguf";
    let result = NNPreload::from_str(preload);
    assert!(result.is_err());
    let err = result.unwrap_err();
    assert_eq!(
        WasmEdgeError::Operation(
            "Failed to convert to NNBackend value. Unknown NNBackend type: CPU".to_string()
        ),
        err
    );

    // invalid preload string: unsupported target
    let preload = "default:GGML:NPU:llama-2-7b-chat.Q5_K_M.gguf";
    let result = NNPreload::from_str(preload);
    assert!(result.is_err());
    let err = result.unwrap_err();
    assert_eq!(
        WasmEdgeError::Operation(
            "Failed to convert to ExecutionTarget value. Unknown ExecutionTarget type: NPU"
                .to_string()
        ),
        err
    );

    // invalid preload string: invalid format
    let preload = "default:GGML:CPU";
    let result = NNPreload::from_str(preload);
    assert!(result.is_err());
    let err = result.unwrap_err();
    assert_eq!(
        WasmEdgeError::Operation(
            "Failed to convert to NNPreload value. Invalid preload string: default:GGML:CPU. The correct format is: 'alias:backend:target:path'"
                .to_string()
        ),
        err
    );
}
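As the tests above illustrate, every parse failure surfaces as the `Operation` variant, so a caller can branch on it. A hedged sketch — the helper is hypothetical, and it assumes `WasmEdgeError` implements `Display`:

```rust
use std::str::FromStr;

// Hypothetical helper: report parse failures and fall back to `None`.
fn try_load(input: &str) -> Option<NNPreload> {
    match NNPreload::from_str(input) {
        Ok(preload) => Some(preload),
        Err(WasmEdgeError::Operation(msg)) => {
            // `from_str` routes format, backend, and target errors here.
            eprintln!("invalid preload string: {msg}");
            None
        }
        Err(err) => {
            eprintln!("unexpected error: {err}");
            None
        }
    }
}
```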

/// Describes the encoding of the graph.
#[cfg(feature = "wasi_nn")]
@@ -162,7 +259,7 @@ impl PluginManager {
    /// * If the path is not given, then the default plugin paths will be used. The default plugin paths are
    ///
    /// * The environment variable "WASMEDGE_PLUGIN_PATH".
    ///
    /// * The `../plugin/` directory related to the WasmEdge installation path.
    ///
    /// * The `wasmedge/` directory under the library path if the WasmEdge is installed under the "/usr".
@@ -186,6 +283,39 @@
    }
}
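For context, a sketch of loading plug-ins along the search order documented above. The directory is illustrative, and the `Option<&Path>` parameter and `WasmEdgeResult` alias are inferred from the doc example below rather than confirmed by this diff:

```rust
use wasmedge_sdk::plugin::PluginManager;

fn load_plugins() -> wasmedge_sdk::WasmEdgeResult<()> {
    // Search the default locations (WASMEDGE_PLUGIN_PATH, ../plugin/, ...).
    PluginManager::load(None)?;

    // Or load from an explicit plug-in directory.
    PluginManager::load(Some(std::path::Path::new("/usr/local/lib/wasmedge")))?;
    Ok(())
}
```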

    /// Initialize the wasi_nn plug-in with the preloads.
    ///
    /// Note that this function is only available after loading the wasi_nn plug-in and before creating the module instance from the plug-in.
    ///
    /// # Argument
    ///
    /// * `preloads` - The preload list.
    ///
    /// # Example
    ///
    /// ```ignore
    /// // load wasinn-pytorch-plugin from the default plugin directory: /usr/local/lib/wasmedge
    /// PluginManager::load(None)?;
    /// // preload named model
    /// PluginManager::nn_preload(vec![NNPreload::new(
    ///     "default",
    ///     GraphEncoding::GGML,
    ///     ExecutionTarget::CPU,
    ///     "llama-2-7b-chat.Q5_K_M.gguf",
    /// )]);
    /// ```
    ///
    /// If a preload is given as a string, use `NNPreload::from_str` to create an `NNPreload` instance:
    ///
    /// ```ignore
    /// use std::str::FromStr;
    ///
    /// // load wasinn-pytorch-plugin from the default plugin directory: /usr/local/lib/wasmedge
    /// PluginManager::load(None)?;
    /// // preload named model
    /// PluginManager::nn_preload(vec![NNPreload::from_str("default:GGML:CPU:llama-2-7b-chat.Q5_K_M.gguf")?]);
    /// ```
    #[cfg(feature = "wasi_nn")]
    #[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
    pub fn nn_preload(preloads: Vec<NNPreload>) {
