Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize arithmetic ops to more combinations of scalars and arrays #782

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions benches/bench1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,22 @@ fn scalar_add_2(bench: &mut test::Bencher) {
bench.iter(|| n + &a);
}

// Benchmark `&array + scalar` on a non-contiguous array: a 64 x 128 array is
// built first, then only every second column is kept (step 2 along axis 1).
#[bench]
fn scalar_add_strided_1(bench: &mut test::Bencher) {
    let full = Array::from_shape_fn((64, 64 * 2), |(i, j)| (i * 64 + j) as f32);
    let a = full.slice_move(s![.., ..;2]);
    let n = 1.;
    bench.iter(|| &a + n);
}

// Benchmark `scalar + &array` (scalar on the left-hand side) on a
// non-contiguous array: a 64 x 128 array is built first, then only every
// second column is kept (step 2 along axis 1).
#[bench]
fn scalar_add_strided_2(bench: &mut test::Bencher) {
    let full = Array::from_shape_fn((64, 64 * 2), |(i, j)| (i * 64 + j) as f32);
    let a = full.slice_move(s![.., ..;2]);
    let n = 1.;
    bench.iter(|| n + &a);
}

#[bench]
fn scalar_sub_1(bench: &mut test::Bencher) {
let a = Array::<f32, _>::zeros((64, 64));
Expand Down
144 changes: 65 additions & 79 deletions src/impl_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,70 +152,56 @@ impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
#[doc=$doc]
/// between the reference `self` and the scalar `x`,
/// and return the result as a new `Array`.
impl<'a, A, S, D, B> $trt<B> for &'a ArrayBase<S, D>
where A: Clone + $trt<B, Output=A>,
impl<'a, A, S, D, B, C> $trt<B> for &'a ArrayBase<S, D>
where A: Clone + $trt<B, Output=C>,
S: Data<Elem=A>,
D: Dimension,
B: ScalarOperand,
{
type Output = Array<A, D>;
fn $mth(self, x: B) -> Array<A, D> {
self.to_owned().$mth(x)
type Output = Array<C, D>;
fn $mth(self, x: B) -> Self::Output {
self.map(move |elt| elt.clone() $operator x.clone())
}
}
);
);

// Pick the expression $a for commutative and $b for ordered binop
macro_rules! if_commutative {
(Commute { $a:expr } or { $b:expr }) => {
$a
};
(Ordered { $a:expr } or { $b:expr }) => {
$b
};
}

macro_rules! impl_scalar_lhs_op {
// $commutative flag. Reuse the self + scalar impl if we can.
// We can do this safely since these are the primitive numeric types
($scalar:ty, $commutative:ident, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => (
// these have no doc -- they are not visible in rustdoc
// Perform elementwise
// between the scalar `self` and array `rhs`,
// and return the result (based on `self`).
impl<S, D> $trt<ArrayBase<S, D>> for $scalar
where S: DataOwned<Elem=$scalar> + DataMut,
D: Dimension,
($scalar:ty, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => (
/// Perform elementwise
#[doc=$doc]
/// between the scalar `self` and array `rhs`,
/// and return the result (based on `rhs`).
impl<A, S, D> $trt<ArrayBase<S, D>> for $scalar
where
$scalar: Clone + $trt<A, Output=A>,
A: Clone,
S: DataOwned<Elem=A> + DataMut,
D: Dimension,
{
type Output = ArrayBase<S, D>;
fn $mth(self, rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
if_commutative!($commutative {
rhs.$mth(self)
} or {{
let mut rhs = rhs;
rhs.unordered_foreach_mut(move |elt| {
*elt = self $operator *elt;
});
rhs
}})
fn $mth(self, mut rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
rhs.unordered_foreach_mut(move |elt| {
*elt = self.clone() $operator elt.clone();
});
rhs
}
}

// Perform elementwise
// between the scalar `self` and array `rhs`,
// and return the result as a new `Array`.
impl<'a, S, D> $trt<&'a ArrayBase<S, D>> for $scalar
where S: Data<Elem=$scalar>,
D: Dimension,
/// Perform elementwise
#[doc=$doc]
/// between the scalar `self` and array `rhs`,
/// and return the result as a new `Array`.
impl<'a, A, S, D, B> $trt<&'a ArrayBase<S, D>> for $scalar
where
$scalar: Clone + $trt<A, Output=B>,
A: Clone,
S: Data<Elem=A>,
D: Dimension,
Copy link
Member

@bluss bluss Dec 29, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This impl somehow now breaks Rust -- see the failed tests -- and causes a recursion error for an expression that has type f32 + f32, which is quite strange/scary(!)

   --> tests/oper.rs:159:48
    |
159 |         .fold(f32::zero(), |acc, (&x, &y)| acc + x * y)
    |                                                ^
    |
    = help: consider adding a `#![recursion_limit="256"]` attribute to your crate (`oper`)
    = note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
    = note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
    = note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
    = note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
    = note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
    = note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
    = note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
    = note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`

Unsure if this is a Rust bug - for example that the impl is accepted(?), but I think this impl is too general and has infinite descent.

Given the question of whether f32 implements Add<&ArrayBase<S, D>>, the compiler looks for another impl that satisfies f32: Add<A> where S: Data<Elem=A>, which looks recursive -- is that it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like a compiler bug to me. As you point out, the expression involves only f32, but for some reason, the error message indicates that one of the arguments is an array. It's also interesting that on my machine with Rust 1.48.0, the error message is slightly different, saying "impl of Add<ndarray::ArrayBase<_, _>> for f32" instead of the error message in your comment "impl of Add<&ndarray::ArrayBase<_, _>> for f32". (Note the &.)

The function fails to compile (with the same error message) even after adding type annotations:

fn reference_dot<'a, V1, V2>(a: V1, b: V2) -> f32
where
    V1: AsArray<'a, f32>,
    V2: AsArray<'a, f32>,
{
    let a: ArrayView1<'a, f32> = a.into();
    let b: ArrayView1<'a, f32> = b.into();
    a.iter()
        .zip(b.iter())
        .fold(f32::zero(), |acc: f32, (&x, &y): (&f32, &f32)| acc + x * y)
}

but if I remove the + x * y, it compiles successfully:

fn reference_dot<'a, V1, V2>(a: V1, b: V2) -> f32
where
    V1: AsArray<'a, f32>,
    V2: AsArray<'a, f32>,
{
    let a: ArrayView1<'a, f32> = a.into();
    let b: ArrayView1<'a, f32> = b.into();
    a.iter()
        .zip(b.iter())
        .fold(f32::zero(), |acc: f32, (&x, &y): (&f32, &f32)| acc)
}

I don't see any reason other than a compiler bug for the first function to fail to compile when the second one compiles without errors, since the type annotations confirm that the closure is operating only on f32 values.

This also compiles successfully:

fn reference_dot2<'a>(a: ArrayView1<'a, f32>, b: ArrayView1<'a, f32>) -> f32 {
    a.iter()
        .zip(b.iter())
        .fold(f32::zero(), |acc: f32, (&x, &y): (&f32, &f32)| acc + x * y)
}

so the bug involves the .into() calls in some way. It's surprising that adding explicit type annotations for the results of the .into() calls, as in the first example, doesn't work around the bug.

Fwiw, I don't think impl<'a, A, S, D, B> $trt<&'a ArrayBase<S, D>> for $scalar is infinitely recursive, since AFAIK it's not possible to have an array of (arrays of (arrays of (arrays of ... [infinite depth]))). The innermost array type can only have an element type that's not an array. You're right that there is recursion if you're dealing with arrays of arrays, but that's the correct behavior, and the recursion is not infinite.

For the particular function we're looking at, the impl doesn't apply, and I don't think the compiler should be trying to apply it. (I think it should only apply the impl if it knows the RHS has some type &ArrayBase<?S, ?D>, where ?S and ?D are inference variables.)

Copy link
Member

@bluss bluss Dec 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, the test runners for cross_test, stable, mips vs i686 disagree with each other about the error too, in the same way, even if they both use Rust 1.48

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reported the issue (with a simplified example) at rust-lang/rust#80542.

{
type Output = Array<$scalar, D>;
fn $mth(self, rhs: &ArrayBase<S, D>) -> Array<$scalar, D> {
if_commutative!($commutative {
rhs.$mth(self)
} or {
self.$mth(rhs.to_owned())
})
type Output = Array<B, D>;
fn $mth(self, rhs: &ArrayBase<S, D>) -> Array<B, D> {
rhs.map(move |elt| self.clone() $operator elt.clone())
}
}
);
Expand All @@ -241,16 +227,16 @@ mod arithmetic_ops {

macro_rules! all_scalar_ops {
($int_scalar:ty) => (
impl_scalar_lhs_op!($int_scalar, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!($int_scalar, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!($int_scalar, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!($int_scalar, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!($int_scalar, Ordered, %, Rem, rem, "remainder");
impl_scalar_lhs_op!($int_scalar, Commute, &, BitAnd, bitand, "bit and");
impl_scalar_lhs_op!($int_scalar, Commute, |, BitOr, bitor, "bit or");
impl_scalar_lhs_op!($int_scalar, Commute, ^, BitXor, bitxor, "bit xor");
impl_scalar_lhs_op!($int_scalar, Ordered, <<, Shl, shl, "left shift");
impl_scalar_lhs_op!($int_scalar, Ordered, >>, Shr, shr, "right shift");
impl_scalar_lhs_op!($int_scalar, +, Add, add, "addition");
impl_scalar_lhs_op!($int_scalar, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!($int_scalar, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!($int_scalar, /, Div, div, "division");
impl_scalar_lhs_op!($int_scalar, %, Rem, rem, "remainder");
impl_scalar_lhs_op!($int_scalar, &, BitAnd, bitand, "bit and");
impl_scalar_lhs_op!($int_scalar, |, BitOr, bitor, "bit or");
impl_scalar_lhs_op!($int_scalar, ^, BitXor, bitxor, "bit xor");
impl_scalar_lhs_op!($int_scalar, <<, Shl, shl, "left shift");
impl_scalar_lhs_op!($int_scalar, >>, Shr, shr, "right shift");
);
}
all_scalar_ops!(i8);
Expand All @@ -264,31 +250,31 @@ mod arithmetic_ops {
all_scalar_ops!(i128);
all_scalar_ops!(u128);

impl_scalar_lhs_op!(bool, Commute, &, BitAnd, bitand, "bit and");
impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or");
impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor");
impl_scalar_lhs_op!(bool, &, BitAnd, bitand, "bit and");
impl_scalar_lhs_op!(bool, |, BitOr, bitor, "bit or");
impl_scalar_lhs_op!(bool, ^, BitXor, bitxor, "bit xor");

impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(f32, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!(f32, Ordered, %, Rem, rem, "remainder");
impl_scalar_lhs_op!(f32, +, Add, add, "addition");
impl_scalar_lhs_op!(f32, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(f32, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(f32, /, Div, div, "division");
impl_scalar_lhs_op!(f32, %, Rem, rem, "remainder");

impl_scalar_lhs_op!(f64, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(f64, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(f64, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder");
impl_scalar_lhs_op!(f64, +, Add, add, "addition");
impl_scalar_lhs_op!(f64, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(f64, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(f64, /, Div, div, "division");
impl_scalar_lhs_op!(f64, %, Rem, rem, "remainder");

impl_scalar_lhs_op!(Complex<f32>, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(Complex<f32>, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(Complex<f32>, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(Complex<f32>, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!(Complex<f32>, +, Add, add, "addition");
impl_scalar_lhs_op!(Complex<f32>, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(Complex<f32>, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(Complex<f32>, /, Div, div, "division");

impl_scalar_lhs_op!(Complex<f64>, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(Complex<f64>, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(Complex<f64>, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(Complex<f64>, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!(Complex<f64>, +, Add, add, "addition");
impl_scalar_lhs_op!(Complex<f64>, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(Complex<f64>, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(Complex<f64>, /, Div, div, "division");

impl<A, S, D> Neg for ArrayBase<S, D>
where
Expand Down