Specify proc blocks more rigorously #328
Some things that should be defined...
An example of how you might implement a tokenizer proc block under this scheme:

```rust
use std::{borrow::Cow, convert::Infallible};
// `Tensor`, `Transform`, and the `ProcBlock`/`arguments`/`argument`/`transform`
// macros are the proposed proc-block API, not std items.

/// A BERT tokenizer.
#[derive(ProcBlock)]
struct Tokenizer {
    word_list: Vec<&'static str>,
}

#[arguments]
impl Tokenizer {
    #[argument]
    pub fn set_word_list(&mut self, value: &'static str) -> Result<(), Infallible> {
        // One word per line; surrounding whitespace and blank lines are ignored.
        self.word_list = value
            .lines()
            .map(|line| line.trim())
            .filter(|line| !line.is_empty())
            .collect();
        Ok(())
    }

    fn tokenize(&self, sentence: &str) -> (Tensor<i32>, Tensor<i32>, Tensor<i32>) { ... }
}

// Accepts a tensor containing a single string.
#[transform(inputs = utf8, outputs = (i32[_], i32[_], i32[_]))]
impl Transform<Tensor<Cow<'static, str>>> for Tokenizer {
    type Output = (Tensor<i32>, Tensor<i32>, Tensor<i32>);

    fn transform(&mut self, input: Tensor<Cow<'static, str>>) -> Self::Output {
        assert_eq!(input.dimensions(), &[1], "This proc block only accepts a tensor containing a single string");
        let sentence = input.get(&[0]).unwrap();
        self.tokenize(sentence);
        ...
    }
}

// Accepts a 1D tensor of raw UTF-8 bytes.
#[transform(inputs = u8[_])]
impl Transform<Tensor<u8>> for Tokenizer {
    type Output = (Tensor<i32>, Tensor<i32>, Tensor<i32>);

    fn transform(&mut self, input: Tensor<u8>) -> Self::Output {
        assert_eq!(input.dimensions().len(), 1, "This proc block only accepts 1D tensors");
        let sentence: &[u8] = input.elements();
        let sentence: &str = core::str::from_utf8(sentence).expect("The input was invalid UTF-8");
        self.tokenize(sentence);
        ...
    }
}
```
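Without the proposed macros, the trait-per-input-type idea can be tried in plain Rust. This is a toy sketch under loud assumptions. `Tensor` is a minimal stand-in that holds only a shape and a flat buffer. `tokenize` is not a BERT tokenizer: it replaces each whitespace-separated word with its index in the word list (`-1` if unknown) and returns one tensor instead of three. The compiler picks the `Transform` impl from the input type.

```rust
use std::borrow::Cow;

// Toy stand-in for the proposed `Tensor` type: a shape plus a flat element buffer.
struct Tensor<T> {
    dimensions: Vec<usize>,
    elements: Vec<T>,
}

// One impl of this trait per accepted input type.
trait Transform<Input> {
    type Output;
    fn transform(&mut self, input: Input) -> Self::Output;
}

struct Tokenizer {
    word_list: Vec<&'static str>,
}

impl Tokenizer {
    // One word per line; surrounding whitespace and blank lines are ignored.
    fn set_word_list(&mut self, value: &'static str) {
        self.word_list = value.lines().map(str::trim).filter(|line| !line.is_empty()).collect();
    }

    // Simplified lookup (hypothetical, not BERT): word -> index in the list, or -1.
    fn tokenize(&self, sentence: &str) -> Tensor<i32> {
        let ids: Vec<i32> = sentence
            .split_whitespace()
            .map(|word| {
                self.word_list
                    .iter()
                    .position(|known| *known == word)
                    .map_or(-1, |index| index as i32)
            })
            .collect();
        Tensor { dimensions: vec![ids.len()], elements: ids }
    }
}

impl Transform<Tensor<Cow<'static, str>>> for Tokenizer {
    type Output = Tensor<i32>;

    fn transform(&mut self, input: Tensor<Cow<'static, str>>) -> Tensor<i32> {
        assert_eq!(input.dimensions, [1], "expected a tensor containing a single string");
        self.tokenize(&input.elements[0])
    }
}

impl Transform<Tensor<u8>> for Tokenizer {
    type Output = Tensor<i32>;

    fn transform(&mut self, input: Tensor<u8>) -> Tensor<i32> {
        assert_eq!(input.dimensions.len(), 1, "expected a 1D tensor of UTF-8 bytes");
        let sentence = std::str::from_utf8(&input.elements).expect("the input was invalid UTF-8");
        self.tokenize(sentence)
    }
}
```

With both impls present, a caller can name the one it wants explicitly, e.g. `<Tokenizer as Transform<Tensor<u8>>>::transform(&mut tokenizer, bytes)`. A mismatched input type is rejected at compile time.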
In terms of documentation and examples, I think most of this would be done in doc-comments on the corresponding procedural macros. That way we can include loads of examples which …
After playing around with Forge a bit more, I think the extra type safety we get by … It's great to get errors from the compiler when you are writing Rust, but a typical Forge user is several steps removed from the Rust source code being compiled. Instead, we should aim for a single, all-encompassing interface that takes a list of tensors as inputs and returns a list of tensors. The tensors should also do type checking internally instead of relying on a generic type parameter. Among other things, this will let us remove the arbitrary restrictions on the maximum number of inputs/outputs, because they can be stored in a slice (e.g. …)
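A minimal sketch of that single interface follows. It assumes a hypothetical type-erased `Tensor` that keeps its elements in an enum, so the element type is checked at runtime instead of through a generic parameter. `ProcBlock::transform`, `ElementType`, `TensorData`, `Error`, and the example `Scale` block are illustrative names, not an existing API.

```rust
// Hypothetical set of element types a tensor can hold.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ElementType {
    U8,
    I32,
    F32,
}

// Type-erased storage: the element type lives in the enum tag, checked at runtime.
#[derive(Debug, Clone, PartialEq)]
enum TensorData {
    U8(Vec<u8>),
    I32(Vec<i32>),
    F32(Vec<f32>),
}

#[derive(Debug, Clone, PartialEq)]
struct Tensor {
    dimensions: Vec<usize>,
    data: TensorData,
}

impl Tensor {
    fn element_type(&self) -> ElementType {
        match self.data {
            TensorData::U8(_) => ElementType::U8,
            TensorData::I32(_) => ElementType::I32,
            TensorData::F32(_) => ElementType::F32,
        }
    }

    // Returns the elements only if this tensor really holds f32s.
    fn as_f32(&self) -> Option<&[f32]> {
        match &self.data {
            TensorData::F32(values) => Some(values.as_slice()),
            _ => None,
        }
    }
}

#[derive(Debug, PartialEq)]
enum Error {
    WrongInputCount { expected: usize, actual: usize },
    WrongElementType { index: usize, expected: ElementType, actual: ElementType },
}

// The single interface: a list of tensors in, a list of tensors out.
trait ProcBlock {
    fn transform(&mut self, inputs: &[Tensor]) -> Result<Vec<Tensor>, Error>;
}

// Example block: multiplies a single f32 tensor by a constant.
struct Scale {
    factor: f32,
}

impl ProcBlock for Scale {
    fn transform(&mut self, inputs: &[Tensor]) -> Result<Vec<Tensor>, Error> {
        let [input] = inputs else {
            return Err(Error::WrongInputCount { expected: 1, actual: inputs.len() });
        };
        let values = input.as_f32().ok_or_else(|| Error::WrongElementType {
            index: 0,
            expected: ElementType::F32,
            actual: input.element_type(),
        })?;
        let scaled: Vec<f32> = values.iter().map(|v| v * self.factor).collect();
        Ok(vec![Tensor { dimensions: input.dimensions.clone(), data: TensorData::F32(scaled) }])
    }
}
```

The inputs arrive as a slice, so the trait does not limit how many a block accepts. Each block checks the count and element types it needs. A mismatch comes back as an `Error` value that a tool like Forge can show to the user, not as a compile error.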
Currently, proc blocks are mostly implemented on an ad-hoc basis with a lot of work left up to the Rust compiler to catch bugs.
We want to take advantage of Rust's procedural macros to enforce a consistent structure and generate metadata, then use WebAssembly custom sections to give external programs access to that metadata without needing to execute the unknown proc block. This is currently done using `#[derive(ProcBlock)]`.

As it is, while we've been generating this metadata for a while, it isn't actually used by anything. That means the way a proc block is implemented and the way it is used can diverge, causing cryptic compilation errors because `rune build` blindly generates invalid code.

There are roughly 3 pieces of information external programs need to know about a proc block:
… `width`) and can be set to a string (see "Arguments in a runefile should just be plain strings" #237)

Later on, we may also include things a proc block requires from the runtime in order to run (e.g. `extern "C"` functions for hardware-accelerated operations).
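An external program can read this metadata without executing the module. The sketch below walks a WebAssembly binary and returns the payloads of every custom section with a given name. It relies only on the WebAssembly binary format: an 8-byte header (`\0asm` plus version 1), then sections. Each section is a one-byte id, a LEB128 size, and its contents. Custom sections have id 0 and begin with a length-prefixed name. The section name `proc_block_metadata` and the JSON payload used to exercise it are hypothetical, not what `#[derive(ProcBlock)]` actually emits.

```rust
/// Reads an unsigned LEB128-encoded u32 starting at `*pos`, advancing `pos`.
fn read_leb128_u32(bytes: &[u8], pos: &mut usize) -> Option<u32> {
    let mut result: u32 = 0;
    let mut shift = 0;
    loop {
        let byte = *bytes.get(*pos)?;
        *pos += 1;
        result |= u32::from(byte & 0x7f) << shift;
        if byte & 0x80 == 0 {
            return Some(result);
        }
        shift += 7;
        if shift >= 32 {
            return None; // malformed: too many continuation bytes for a u32
        }
    }
}

/// Returns the payload of every custom section called `name`,
/// or `None` if the bytes are not a well-formed module header/section list.
fn custom_sections<'a>(wasm: &'a [u8], name: &str) -> Option<Vec<&'a [u8]>> {
    // Magic number "\0asm" followed by version 1 (little-endian).
    if wasm.get(..8)? != b"\0asm\x01\0\0\0" {
        return None;
    }
    let mut pos = 8;
    let mut found = Vec::new();
    while pos < wasm.len() {
        let id = wasm[pos];
        pos += 1;
        let size = read_leb128_u32(wasm, &mut pos)? as usize;
        let end = pos.checked_add(size)?;
        let contents = wasm.get(pos..end)?;
        if id == 0 {
            // Custom section: a length-prefixed UTF-8 name, then the raw payload.
            let mut inner = 0;
            let name_len = read_leb128_u32(contents, &mut inner)? as usize;
            let name_end = inner.checked_add(name_len)?;
            if contents.get(inner..name_end)? == name.as_bytes() {
                found.push(&contents[name_end..]);
            }
        }
        pos = end; // skip to the next section without interpreting its contents
    }
    Some(found)
}
```

On the producing side, Rust can place a `static` byte array in a custom section via the `#[link_section = "..."]` attribute when compiling for `wasm32`. A tool like `rune build` could therefore check a runefile against a proc block's declared metadata before generating any code.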