Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

add bert model #398

Merged
merged 1 commit into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions crates/ggml/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,12 @@ impl Context {
pub fn storage(&self) -> &ContextStorage {
self.storage.as_ref().unwrap()
}

/// Fills every element of `a` with the value `x`.
///
/// The fill is delegated to `ggml_set_f32`; whatever tensor pointer that call
/// hands back is wrapped into a [`Tensor`] through this context.
pub fn set_f32(&self, a: &Tensor, x: f32) -> Tensor {
    let target = a.ptr.as_ptr();
    // NOTE(review): soundness relies on `a` wrapping a valid, live ggml tensor
    // pointer - confirm against how `Tensor` is constructed.
    let filled = unsafe { sys::ggml_set_f32(target, x) };
    self.new_tensor_raw(filled)
}
}
// Operations
impl Context {
Expand Down Expand Up @@ -598,6 +604,30 @@ impl Context {
};
self.new_tensor_raw(tensor)
}

/// Builds a new tensor holding the square of `a`.
///
/// Thin wrapper around `ggml_sqr`, called with this context's raw pointer.
pub fn op_sqr(&self, a: &Tensor) -> Tensor {
    let (ctx, input) = (self.as_ptr(), a.ptr.as_ptr());
    // NOTE(review): assumes both the context and `a` wrap valid ggml pointers.
    let squared = unsafe { sys::ggml_sqr(ctx, input) };
    self.new_tensor_raw(squared)
}

/// Builds a new tensor holding the square root of `a`.
///
/// Thin wrapper around `ggml_sqrt`, called with this context's raw pointer.
pub fn op_sqrt(&self, a: &Tensor) -> Tensor {
    let (ctx, input) = (self.as_ptr(), a.ptr.as_ptr());
    // NOTE(review): assumes both the context and `a` wrap valid ggml pointers.
    let root = unsafe { sys::ggml_sqrt(ctx, input) };
    self.new_tensor_raw(root)
}

/// Creates a new tensor by applying `ggml_sum` to `a`.
///
/// NOTE(review): presumably this reduces all elements of `a` to their sum -
/// confirm against the ggml documentation for `ggml_sum`.
pub fn op_sum(&self, a: &Tensor) -> Tensor {
    let (ctx, input) = (self.as_ptr(), a.ptr.as_ptr());
    // NOTE(review): assumes both the context and `a` wrap valid ggml pointers.
    let summed = unsafe { sys::ggml_sum(ctx, input) };
    self.new_tensor_raw(summed)
}

/// Creates a new tensor by applying `ggml_div` to `a` and `b`, in that order.
///
/// NOTE(review): presumably this is the element-wise quotient `a / b` -
/// confirm against the ggml documentation for `ggml_div`.
pub fn op_div(&self, a: &Tensor, b: &Tensor) -> Tensor {
    let ctx = self.as_ptr();
    let dividend = a.ptr.as_ptr();
    let divisor = b.ptr.as_ptr();
    // NOTE(review): assumes the context, `a` and `b` all wrap valid ggml pointers.
    let quotient = unsafe { sys::ggml_div(ctx, dividend, divisor) };
    self.new_tensor_raw(quotient)
}
}
// Public to this crate methods
impl Context {
Expand Down
4 changes: 3 additions & 1 deletion crates/llm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ llm-bloom = { path = "../models/bloom", optional = true, version = "0.2.0-dev" }
llm-gptneox = { path = "../models/gptneox", optional = true, version = "0.2.0-dev" }
llm-mpt = { path = "../models/mpt", optional = true, version = "0.2.0-dev" }
llm-falcon = { path = "../models/falcon", optional = true, version = "0.2.0-dev" }
llm-bert = { path = "../models/bert", optional = true, version = "0.2.0-dev" }

serde = { workspace = true }
tracing = { workspace = true }
Expand All @@ -34,13 +35,14 @@ default = ["models", "tokenizers-remote"]

tokenizers-remote = ["llm-base/tokenizers-remote"]

models = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt"]
models = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt", "bert"]
llama = ["dep:llm-llama"]
gpt2 = ["dep:llm-gpt2"]
gptj = ["dep:llm-gptj"]
bloom = ["dep:llm-bloom"]
gptneox = ["dep:llm-gptneox"]
mpt = ["dep:llm-mpt"]
bert = ["dep:llm-bert"]
# Falcon is off by default. See `llm_falcon`'s module documentation for more information.
falcon = ["dep:llm-falcon"]

Expand Down
1 change: 1 addition & 0 deletions crates/llm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ macro_rules! define_models {
}

define_models!(
(bert, "bert", Bert, llm_bert, "Bert"),
(bloom, "bloom", Bloom, llm_bloom, "BLOOM"),
(gpt2, "gpt2", Gpt2, llm_gpt2, "GPT-2"),
(gptj, "gptj", GptJ, llm_gptj, "GPT-J"),
Expand Down
14 changes: 14 additions & 0 deletions crates/models/bert/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Manifest for the `llm-bert` crate.
[package]
name = "llm-bert"
version = "0.2.0-dev"
# `license` and `repository` are inherited from the workspace manifest.
license = { workspace = true }
repository = { workspace = true }
description = "An implementation of BERT for the `llm` ecosystem."
edition = "2021"
# Points at a README three directories above this crate.
readme = "../../../README.md"

[dependencies]
# Version and features for `bytemuck` come from the workspace manifest.
bytemuck.workspace = true
# Path dependency on the sibling `llm-base` crate; the explicit version sits alongside the path.
llm-base = { path = "../../llm-base", version = "0.2.0-dev" }
# Declared here directly rather than inherited from the workspace.
tracing = { version = "0.1", features = ["log"] }

Loading
Loading