Skip to content

Commit

Permalink
Add more checks and panics
Browse files Browse the repository at this point in the history
  • Loading branch information
JC committed Jul 15, 2024
1 parent 02b5260 commit 40f2a78
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 18 deletions.
49 changes: 32 additions & 17 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -748,43 +748,58 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) {
/// Create a PadConfig from the attributes of the node
pub fn pad_config(node: &Node) -> PadConfig {
fn get_pads(node: &Node) -> Vec<usize> {
if node.inputs.len() != 3 {
panic!("Pad: must provide three inputs")
}

let input_dim = match &node.inputs.first().unwrap().ty {
ArgType::Tensor(tensor) => tensor.dim,
_ => panic!("Only tensor input is valid"),
_ => panic!("Pad: Only tensor input is valid"),
};

let pads: Vec<usize> = match &node.inputs[1].value {
Some(Data::Int64s(shape)) => shape
.iter()
.map(|&x| {
if x < 0 {
// TODO: support negative pads
panic!("Negative pad is not supported");
panic!("Pad: Negative pad is not supported");
}
x as usize
})
.collect(),
_ => panic!("Pads data type must be int64"),
_ => panic!("Pad: pads data type must be int64"),
};

assert_eq!(pads.len(), input_dim * 2);
if pads.len() != input_dim * 2 {
panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]");
}
// TODO: Burn's pad should support 1D tensor
if input_dim < 2 {
panic!("Pad: input tensor should be rank 2 or higher");
}

if input_dim == 1 {
vec![pads[0], pads[1], 0, 0]
} else {
let left = pads[input_dim - 1];
let top = pads[input_dim - 2];
let right = pads[pads.len() - 1];
let bottom = pads[pads.len() - 2];
vec![left, right, top, bottom]
if input_dim > 2 {
log::warn!("Pad: padding will only be applied to the last two dimensions");
}

let left = pads[input_dim - 1];
let top = pads[input_dim - 2];
let right = pads[pads.len() - 1];
let bottom = pads[pads.len() - 2];
vec![left, right, top, bottom]
}
fn get_constant_value(node: &Node) -> f32 {
// TODO: support int, boolean
match &node.inputs[2].value {
Some(Data::Float32s(shape)) => shape.first().unwrap().to_owned(),
_ => panic!("Pad: should provide a constant value input to pad with, for example 0.0"),
}
}

let pads = get_pads(node);
// TODO: support int, boolean
let constant_value: f32 = match &node.inputs[2].value {
Some(Data::Float32s(shape)) => shape.first().unwrap().to_owned(),
_ => 0.,
};
let constant_value = get_constant_value(node);

PadConfig::new(pads, constant_value)
}

Expand Down
2 changes: 1 addition & 1 deletion crates/burn-tensor/src/tensor/api/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ where
)
}

/// Pad the tensor with the given value on the last two dimensions.
/// Pad the tensor of rank two or higher with the given value on the last two dimensions.
///
/// # Arguments
///
Expand Down

0 comments on commit 40f2a78

Please sign in to comment.