Skip to content

Commit

Permalink
Add CompatibilityTensorRetrievalBackend
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Sep 1, 2024
1 parent 4518c8e commit 83ec0fd
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 4 deletions.
29 changes: 25 additions & 4 deletions candle-holder-models/src/from_pretrained.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use candle_holder::{
};

use crate::generation::config::GenerationConfig;
use crate::utils::var_builder::CompatibilityTensorRetrievalBackend;

const MODEL_GENERATION_CONFIG_FILE: &str = "generation_config.json";
const MODEL_SAFETENSORS_INDEX_FILE: &str = "model.safetensors.index.json";
Expand Down Expand Up @@ -48,13 +49,33 @@ impl ModelInfo {
///
/// A `VarBuilder` containing the model weights.
pub fn get_var_builder(&self, dtype: DType, device: &Device) -> Result<VarBuilder> {
// NOTE(review): this listing is a rendered diff without +/- markers, so lines from
// the removed and the added version are interleaved. The `let vb = match ...`
// header, its `VarBuilder::from_pth` / `VarBuilder::from_mmaped_safetensors` arms
// and the trailing `Ok(vb)` appear to be the 4 deleted lines; the `backend` lines
// are their replacement.
let vb = match self.from_pth {
true => VarBuilder::from_pth(&self.weights_file_paths[0], dtype, device)?,
// The model name is handed to the compatibility backend, which uses it as the
// tensor-name prefix it may strip during lookup.
let model_name = self.get_model_name();
// Pick the backend matching the weights format: a single `.pth` file (only the
// first path is used) or one or more memory-mapped safetensors files.
let backend = match self.from_pth {
true => CompatibilityTensorRetrievalBackend::from_pth(
&self.weights_file_paths[0],
model_name,
)?,
false => unsafe {
VarBuilder::from_mmaped_safetensors(&self.weights_file_paths, dtype, device)?
CompatibilityTensorRetrievalBackend::from_mmaped_safetensors(
&self.weights_file_paths,
model_name,
)?
},
};
Ok(vb)
// Wrap the backend in a `VarBuilder` bound to the requested dtype and device.
Ok(VarBuilder::from_backend(
Box::new(backend),
dtype,
device.clone(),
))
}

/// Returns the model name taken from the `model_type` entry of the configuration.
///
/// # Returns
///
/// The `model_type` value as an owned `String`, or an empty string when there is
/// no configuration, the key is absent, or its value is not a string.
pub fn get_model_name(&self) -> String {
    let model_type = match self.config.as_ref() {
        Some(config) => config
            .get("model_type")
            .and_then(|value| value.as_str()),
        None => None,
    };

    match model_type {
        Some(name) => name.to_string(),
        None => String::new(),
    }
}

/// Gets a reference to the model configuration.
Expand Down
1 change: 1 addition & 0 deletions candle-holder-models/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pub mod attn_mask;
pub mod cache;
pub mod flash_attn;
pub mod rope;
pub mod var_builder;
114 changes: 114 additions & 0 deletions candle-holder-models/src/utils/var_builder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use candle_core::{DType, Device, Shape, Tensor};
use candle_nn::{var_builder::SimpleBackend, Init};

/// A backend for retrieving tensors that ensures compatibility with old tensor naming conventions.
/// When a tensor is not found under the requested name, alternative names are tried so that
/// the following cases are handled:
///
/// 1. The model prefix is missing from the tensor name.
/// 2. Tensors are stored under the legacy names `gamma` and `beta` rather than `weight` and `bias`.
///
/// This struct wraps a `SimpleBackend` implementation and provides an additional
/// `model_name` field to support model-specific tensor retrieval operations.
pub struct CompatibilityTensorRetrievalBackend {
/// The wrapped backend that all (possibly renamed) lookups are forwarded to.
inner: Box<dyn SimpleBackend>,
/// Name of the model; `"{model_name}."` is the prefix that may be stripped from
/// a requested tensor name during lookup.
model_name: String,
}

impl CompatibilityTensorRetrievalBackend {
    /// Create a new `CompatibilityTensorRetrievalBackend` wrapping an existing backend.
    ///
    /// # Arguments
    ///
    /// * `inner` - The backend that actually provides the tensors.
    /// * `model_name` - The name of the model, used as the tensor name prefix.
    ///
    /// # Returns
    ///
    /// A `CompatibilityTensorRetrievalBackend`.
    pub fn new(inner: Box<dyn SimpleBackend>, model_name: String) -> Self {
        Self { inner, model_name }
    }

    /// Create a new `CompatibilityTensorRetrievalBackend` from a `PthTensors` instance that reads tensors from a `.pth` file.
    ///
    /// # Arguments
    ///
    /// * `p` - The path to the `.pth` file.
    /// * `model_name` - The name of the model.
    ///
    /// # Returns
    ///
    /// A `CompatibilityTensorRetrievalBackend`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `PthTensors::new`.
    pub fn from_pth<P: AsRef<std::path::Path>>(
        p: P,
        model_name: String,
    ) -> candle_core::Result<Self> {
        let pth = candle_core::pickle::PthTensors::new(p, None)?;
        Ok(Self::new(Box::new(pth), model_name))
    }

    /// Create a new `CompatibilityTensorRetrievalBackend` from a `MmapedSafetensors` instance that reads tensors from a `.safetensors` file.
    ///
    /// # Arguments
    ///
    /// * `paths` - A list of paths to the `.safetensors` files.
    /// * `model_name` - The name of the model.
    ///
    /// # Returns
    ///
    /// A `CompatibilityTensorRetrievalBackend`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `MmapedSafetensors::multi`.
    ///
    /// # Safety
    ///
    /// This calls the `unsafe` function `MmapedSafetensors::multi`, so the caller must
    /// uphold its requirements. NOTE(review): presumably the files must not be
    /// modified or truncated while they are memory-mapped; confirm against the
    /// candle documentation.
    pub unsafe fn from_mmaped_safetensors<P: AsRef<std::path::Path>>(
        paths: &[P],
        model_name: String,
    ) -> candle_core::Result<Self> {
        let tensors = candle_core::safetensors::MmapedSafetensors::multi(paths)?;
        Ok(Self::new(Box::new(tensors), model_name))
    }

    /// Maps a tensor name ending in `weight`/`bias` to its legacy spelling.
    ///
    /// Legacy checkpoints name the normalization scale `gamma` and the shift `beta`,
    /// so `weight` becomes `gamma` and `bias` becomes `beta`. Only the last
    /// dot-separated component is rewritten, so path components that merely contain
    /// those words are left untouched. Returns `None` when the last component is
    /// neither `weight` nor `bias`.
    fn legacy_norm_name(name: &str) -> Option<String> {
        let (parent, leaf) = match name.rsplit_once('.') {
            Some((parent, leaf)) => (Some(parent), leaf),
            None => (None, name),
        };

        let legacy_leaf = match leaf {
            "weight" => "gamma",
            "bias" => "beta",
            _ => return None,
        };

        Some(match parent {
            Some(parent) => format!("{}.{}", parent, legacy_leaf),
            None => legacy_leaf.to_string(),
        })
    }

    /// Resolves `name` to the name under which the tensor is stored in the inner backend.
    ///
    /// The requested name is used as-is when the inner backend contains it. Otherwise
    /// the following candidates are tried in order and the first one found is returned:
    ///
    /// 1. `name` without the leading `"{model_name}."` prefix,
    /// 2. `name` with the legacy `gamma`/`beta` leaf,
    /// 3. the prefix-less name with the legacy `gamma`/`beta` leaf.
    ///
    /// If none of them exists, the original name is returned so that the inner
    /// backend reports the lookup failure for the name the caller asked for.
    fn rename(&self, name: &str) -> String {
        // Check if the original name exists
        if self.inner.contains_tensor(name) {
            return name.to_string();
        }

        // An empty model name carries no prefix; without this guard the prefix
        // would degenerate to a bare ".".
        let without_prefix = if self.model_name.is_empty() {
            None
        } else {
            name.strip_prefix(self.model_name.as_str())
                .and_then(|rest| rest.strip_prefix('.'))
        };

        // Candidates that do not apply are `None` and skipped below.
        let candidates = [
            without_prefix.map(str::to_string),
            Self::legacy_norm_name(name),
            without_prefix.and_then(Self::legacy_norm_name),
        ];

        candidates
            .into_iter()
            .flatten()
            .find(|candidate| self.inner.contains_tensor(candidate))
            // If no matching tensor is found, return the original name
            .unwrap_or_else(|| name.to_string())
    }
}

/// Forwards every lookup to the inner backend after resolving the requested name
/// through `rename`.
impl SimpleBackend for CompatibilityTensorRetrievalBackend {
    /// Retrieves the tensor stored under the resolved form of `name`; all other
    /// arguments are passed to the inner backend unchanged.
    fn get(
        &self,
        s: Shape,
        name: &str,
        h: Init,
        dtype: DType,
        dev: &Device,
    ) -> candle_core::Result<Tensor> {
        self.inner.get(s, self.rename(name).as_str(), h, dtype, dev)
    }

    /// Returns whether the inner backend contains the resolved form of `name`.
    fn contains_tensor(&self, name: &str) -> bool {
        self.inner.contains_tensor(self.rename(name).as_str())
    }
}

0 comments on commit 83ec0fd

Please sign in to comment.