Skip to content

Commit

Permalink
Fix quant.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Oct 17, 2023
1 parent ca11745 commit a3301ac
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 11 deletions.
2 changes: 1 addition & 1 deletion assets/configs/Config.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[model]
path = "assets/models/RWKV-4-World-0.4B-v1-20230529-ctx4096.st" # Path to the model.
quant = [] # Layers to be quantized.
quant = 0 # Number of layers (from the bottom) to be quantized; 0 disables quantization.
token_chunk_size = 32 # Size of token chunk that is inferred at once. For high end GPUs, this could be 64 or 128 (faster).
head_chunk_size = 8192 # DO NOT modify this if you don't know what you are doing.
max_runtime_batch = 8 # The maximum batches that can be scheduled for inference at the same time.
Expand Down
2 changes: 1 addition & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub struct ModelConfig {
/// Path to the model.
pub path: PathBuf,
/// Specify the number of layers to be quantized.
pub quant: Vec<usize>,
pub quant: usize,
/// Maximum tokens to be processed in parallel at once.
pub token_chunk_size: usize,
/// The chunk size for each split of the head matrix.
Expand Down
13 changes: 4 additions & 9 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ pub struct ReloadRequest {
/// Path to the model.
pub model_path: PathBuf,
/// Specify the number of layers to be quantized.
pub quant: Vec<usize>,
pub quant: usize,
/// Maximum tokens to be processed in parallel at once.
pub token_chunk_size: usize,
/// The chunk size for each split of the head matrix.
Expand Down Expand Up @@ -209,14 +209,9 @@ where
head_chunk_size,
..
} = request;
let quant = if quant.is_empty() {
Quantization::None
} else {
let mut layers = LayerFlags::empty();
quant
.into_iter()
.for_each(|x| layers.insert(LayerFlags::from_layer(x as u64)));
Quantization::Int8(layers)
let quant = match quant {
0 => Quantization::None,
x => Quantization::Int8(LayerFlags::from_bits_retain((1 << x) - 1)),
};

let model: M = ModelBuilder::new(context, data)
Expand Down

0 comments on commit a3301ac

Please sign in to comment.