Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Implement FromStr trait for NNPreload in wasmedge-sdk #81

Merged
merged 2 commits into from
Oct 26, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 135 additions & 5 deletions src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,35 @@ pub mod ffi {
};
}

/// Preload config for initializing the wasi_nn plug-in.
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
#[derive(Debug)]
pub struct NNPreload {
/// The alias of the model in the WASI-NN environment.
pub alias: String,
alias: String,
/// The inference backend.
pub backend: GraphEncoding,
backend: GraphEncoding,
/// The execution target, on which the inference runs.
pub target: ExecutionTarget,
target: ExecutionTarget,
/// The path to the model file. Note that the path is the guest path instead of the host path.
pub path: std::path::PathBuf,
path: std::path::PathBuf,
}
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl NNPreload {
/// Creates a new preload config.
///
/// # Arguments
///
/// * `alias` - The alias of the model in the WASI-NN environment.
///
/// * `backend` - The inference backend.
///
/// * `target` - The execution target, on which the inference runs.
///
/// * `path` - The path to the model file. Note that the path is the guest path instead of the host path.
///
pub fn new(
alias: impl AsRef<str>,
backend: GraphEncoding,
Expand Down Expand Up @@ -59,6 +72,90 @@ impl std::fmt::Display for NNPreload {
)
}
}
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl std::str::FromStr for NNPreload {
    type Err = WasmEdgeError;

    /// Parses a preload description written as `alias:backend:target:path`.
    ///
    /// The input is split on every `:`; exactly four fields are required. The second
    /// field is parsed as a `GraphEncoding` and the third as an `ExecutionTarget`;
    /// the first and fourth are taken verbatim as the alias and the model path.
    ///
    /// # Errors
    ///
    /// Returns `WasmEdgeError::Operation` when the input does not contain exactly four
    /// `:`-separated fields, or when the backend or target field cannot be parsed.
    fn from_str(preload: &str) -> std::result::Result<Self, Self::Err> {
        let fields: Vec<&str> = preload.split(':').collect();

        match fields.as_slice() {
            [alias, backend, target, path] => {
                // Backend is validated before target, so a string with both fields
                // invalid reports the backend error.
                let backend = backend
                    .parse::<GraphEncoding>()
                    .map_err(|e| WasmEdgeError::Operation(e.to_string()))?;
                let target = target
                    .parse::<ExecutionTarget>()
                    .map_err(|e| WasmEdgeError::Operation(e.to_string()))?;

                Ok(Self::new(
                    alias.to_string(),
                    backend,
                    target,
                    std::path::PathBuf::from(*path),
                ))
            }
            // Any field count other than four is rejected up front.
            _ => Err(WasmEdgeError::Operation(format!(
                "Failed to convert to NNPreload value. Invalid preload string: {}. The correct format is: 'alias:backend:target:path'",
                preload
            ))),
        }
    }
}

#[cfg(feature = "wasi_nn")]
#[test]
fn test_generate_nnpreload_from_str() {
    use std::str::FromStr;

    // Shared check for the failure cases: parsing `input` must fail with
    // `WasmEdgeError::Operation` carrying exactly `expected`.
    let assert_operation_error = |input: &str, expected: &str| {
        let outcome = NNPreload::from_str(input);
        assert!(outcome.is_err());
        assert_eq!(
            WasmEdgeError::Operation(expected.to_string()),
            outcome.unwrap_err()
        );
    };

    // A well-formed preload string populates all four fields.
    let parsed = NNPreload::from_str("default:GGML:CPU:llama-2-7b-chat.Q5_K_M.gguf");
    assert!(parsed.is_ok());
    let config = parsed.unwrap();
    assert_eq!(config.alias, "default");
    assert_eq!(config.backend, GraphEncoding::GGML);
    assert_eq!(config.target, ExecutionTarget::CPU);
    assert_eq!(
        config.path,
        std::path::PathBuf::from("llama-2-7b-chat.Q5_K_M.gguf")
    );

    // Backend and target swapped: "CPU" is rejected as a backend.
    assert_operation_error(
        "default:CPU:GGML:llama-2-7b-chat.Q5_K_M.gguf",
        "Failed to convert to NNBackend value. Unknown NNBackend type: CPU",
    );

    // Unsupported execution target.
    assert_operation_error(
        "default:GGML:NPU:llama-2-7b-chat.Q5_K_M.gguf",
        "Failed to convert to ExecutionTarget value. Unknown ExecutionTarget type: NPU",
    );

    // Too few fields: the path component is missing.
    assert_operation_error(
        "default:GGML:CPU",
        "Failed to convert to NNPreload value. Invalid preload string: default:GGML:CPU. The correct format is: 'alias:backend:target:path'",
    );
}

/// Describes the encoding of the graph.
#[cfg(feature = "wasi_nn")]
Expand Down Expand Up @@ -162,7 +259,7 @@ impl PluginManager {
/// * If the path is not given, then the default plugin paths will be used. The default plugin paths are
///
/// * The environment variable "WASMEDGE_PLUGIN_PATH".
///
///
/// * The `../plugin/` directory related to the WasmEdge installation path.
///
/// * The `wasmedge/` directory under the library path if the WasmEdge is installed under the "/usr".
Expand All @@ -186,6 +283,39 @@ impl PluginManager {
}
}

/// Initialize the wasi_nn plug-in with the preloads.
///
/// Note that this function is only available after loading the wasi_nn plug-in and before creating the module instance from the plug-in.
///
/// # Argument
///
/// * `preloads` - The preload list.
///
/// # Example
///
/// ```ignore
/// // load wasinn-pytorch-plugin from the default plugin directory: /usr/local/lib/wasmedge
/// PluginManager::load(None)?;
/// // preload named model
/// PluginManager::nn_preload(vec![NNPreload::new(
/// "default",
/// GraphEncoding::GGML,
/// ExecutionTarget::CPU,
/// "llama-2-7b-chat.Q5_K_M.gguf",
/// )]);
/// ```
///
/// If a preload is given as a string, use `NNPreload::from_str` to create an `NNPreload` instance:
///
/// ```ignore
/// use std::str::FromStr;
///
/// // load wasinn-pytorch-plugin from the default plugin directory: /usr/local/lib/wasmedge
/// PluginManager::load(None)?;
/// // preload named model
/// PluginManager::nn_preload(vec![NNPreload::from_str("default:GGML:CPU:llama-2-7b-chat.Q5_K_M.gguf")?]);
///
/// ```
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
pub fn nn_preload(preloads: Vec<NNPreload>) {
Expand Down