Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Retrieving Named Attributes #913

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/wrappers/jit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use super::utils::{path_to_cstring, ptr_to_string};
use super::{device::Device, kind::Kind};
use crate::{nn::Path, TchError, Tensor};
use libc::{c_int, c_void};
use libc::{c_char, c_int, c_void};
use std::borrow::Borrow;
use std::convert::TryFrom;
use torch_sys::*;
Expand Down Expand Up @@ -403,6 +403,14 @@ impl IValue {
}
Ok(v)
}

pub(super) extern "C" fn add_callback(data: *mut c_void, name: *const c_char, c_ivalue: *mut CIValue) {
let v = unsafe { &mut *(data as *mut Vec<(String, Result<IValue, TchError>)>) };
let name = unsafe { std::ffi::CStr::from_ptr(name).to_str().unwrap() };
let name = name.replace('|', ".");
let value = Self::from_c(c_ivalue);
v.push((name, value));
}
}

/// A jit PyTorch module.
Expand Down Expand Up @@ -593,6 +601,17 @@ impl CModule {
Ok(v)
}

/// Loads the named attributes on a module
pub fn named_attributes(&self) -> Result<Vec<(String, IValue)>, TchError> {
let mut v: Vec<(String, Result<IValue, TchError>)> = vec![];
unsafe_torch_err!(atm_named_attributes(
self.c_module,
&mut v as *mut _ as *mut c_void,
IValue::add_callback
));
v.into_iter().map(|(k, v)| Ok((k, v?))).collect()
}

/// Create a new module by tracing the application of the specified function on
/// the given inputs.
pub fn create_by_tracing<F>(
Expand Down
Binary file added tests/foo9.pt
Binary file not shown.
17 changes: 17 additions & 0 deletions tests/jit_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,23 @@ fn jit4() {
assert_eq!(v, 14.0);
let named_parameters = mod_.named_parameters().unwrap();
assert_eq!(named_parameters, vec![]);

// Even "empty" models have some attributes.
let named_attributes = mod_.named_attributes().unwrap();
assert_eq!(named_attributes, vec![
(String::from("training"), IValue::Bool(true)),
(String::from("_is_full_backward_hook"), IValue::None),
]);
}

#[test]
fn jit_named_attributes() {
// Check that models with user-defined attributes are correctly loaded.
let mod_ = tch::CModule::load("tests/foo9.pt").unwrap();
let named_attributes = mod_.named_attributes().unwrap();
assert!(named_attributes.len() > 2);
assert!(named_attributes.iter().any(|(name, _)| name == "embedding_length"));
assert!(named_attributes.iter().any(|(name, _)| name == "max_distance"));
}

#[test]
Expand Down
9 changes: 9 additions & 0 deletions torch-sys/libtch/torch_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1317,6 +1317,15 @@ void atm_named_parameters(module m, void *data, void (*f)(void *, char *, tensor
)
}

void atm_named_attributes(module m, void *data, void (*f)(void *, char *, ivalue)) {
PROTECT(
for (const auto &p : m->named_attributes()) {
auto v = p.value;
f(data, (char*)p.name.c_str(), new torch::IValue(v));
}
)
}

ivalue ati_tensor(tensor t) {
PROTECT(
return new torch::jit::IValue(*t);
Expand Down
1 change: 1 addition & 0 deletions torch-sys/libtch/torch_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ void atm_set_profiling_mode(int);
void atm_fuser_cuda_set_enabled(bool);
bool atm_fuser_cuda_is_enabled();
void atm_named_parameters(module, void *data, void (*f)(void *, char *, tensor));
void atm_named_attributes(module, void *data, void (*f)(void *, char *, ivalue));

// This function has to be followed by a call to atm_end_tracing.
module atm_create_for_tracing(char *modl_name, tensor *inputs, int ninputs);
Expand Down
5 changes: 5 additions & 0 deletions torch-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,11 @@ extern "C" {
data: *mut c_void,
f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
);
pub fn atm_named_attributes(
m: *mut CModule_,
data: *mut c_void,
f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut CIValue),
);
pub fn atm_create_for_tracing(
modl_name: *const c_char,
inputs: *const *mut C_tensor,
Expand Down