Skip to content

Commit

Permalink
Add more checks and panics
Browse files Browse the repository at this point in the history
  • Loading branch information
JC committed Jul 15, 2024
1 parent 02b5260 commit 08cf24f
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -747,23 +747,28 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) {

/// Create a PadConfig from the attributes of the node
pub fn pad_config(node: &Node) -> PadConfig {
fn check_inputs(node: &Node) {
if node.inputs.len() == 1 {
panic!("Pad: must provide at least two inputs")
}
}
fn get_pads(node: &Node) -> Vec<usize> {
let input_dim = match &node.inputs.first().unwrap().ty {
ArgType::Tensor(tensor) => tensor.dim,
_ => panic!("Only tensor input is valid"),
_ => panic!("Pad: Only tensor input is valid"),
};
let pads: Vec<usize> = match &node.inputs[1].value {
Some(Data::Int64s(shape)) => shape
.iter()
.map(|&x| {
if x < 0 {
// TODO: support negative pads
panic!("Negative pad is not supported");
panic!("Pad: Negative pad is not supported");
}
x as usize
})
.collect(),
_ => panic!("Pads data type must be int64"),
_ => panic!("Pad: pads data type must be int64"),
};

assert_eq!(pads.len(), input_dim * 2);
Expand All @@ -778,13 +783,22 @@ pub fn pad_config(node: &Node) -> PadConfig {
vec![left, right, top, bottom]
}
}
fn get_constant_value(node: &Node) -> f32 {
// TODO: support int, boolean
if node.inputs.len() < 3 {
0.
} else {
match &node.inputs[2].value {
Some(Data::Float32s(shape)) => shape.first().unwrap().to_owned(),
_ => 0.,
}
}
}

check_inputs(node);
let pads = get_pads(node);
// TODO: support int, boolean
let constant_value: f32 = match &node.inputs[2].value {
Some(Data::Float32s(shape)) => shape.first().unwrap().to_owned(),
_ => 0.,
};
let constant_value = get_constant_value(node);

PadConfig::new(pads, constant_value)
}

Expand Down

0 comments on commit 08cf24f

Please sign in to comment.