77// except according to those terms.
88
99use crate :: dimension:: BroadcastShape ;
10- use crate :: data_traits:: MaybeUninitSubst ;
1110use crate :: Zip ;
1211use num_complex:: Complex ;
1312
@@ -68,8 +67,8 @@ impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
6867where
6968 A : Clone + $trt<B , Output =A >,
7069 B : Clone ,
71- S : DataOwned <Elem =A > + DataMut + MaybeUninitSubst < A > ,
72- < S as MaybeUninitSubst < A >> :: Output : DataMut ,
70+ S : DataOwned <Elem =A > + DataMut ,
71+ S :: MaybeUninit : DataMut ,
7372 S2 : Data <Elem =B >,
7473 D : Dimension + BroadcastShape <E >,
7574 E : Dimension ,
@@ -96,38 +95,24 @@ impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
9695where
9796 A : Clone + $trt<B , Output =A >,
9897 B : Clone ,
99- S : DataOwned <Elem =A > + DataMut + MaybeUninitSubst < A > ,
100- < S as MaybeUninitSubst < A >> :: Output : DataMut ,
98+ S : DataOwned <Elem =A > + DataMut ,
99+ S :: MaybeUninit : DataMut ,
101100 S2 : Data <Elem =B >,
102101 D : Dimension + BroadcastShape <E >,
103102 E : Dimension ,
104103{
105104 type Output = ArrayBase <S , <D as BroadcastShape <E >>:: Output >;
106105 fn $mth( self , rhs: & ArrayBase <S2 , E >) -> Self :: Output
107106 {
108- let shape = self . dim. broadcast_shape( & rhs. dim) . unwrap( ) ;
109- if shape. slice( ) == self . dim. slice( ) {
107+ if self . ndim( ) == rhs. ndim( ) && self . shape( ) == rhs. shape( ) {
110108 let mut out = self . into_dimensionality:: <<D as BroadcastShape <E >>:: Output >( ) . unwrap( ) ;
111- out. zip_mut_with( rhs, |x, y| {
112- * x = x. clone( ) $operator y. clone( ) ;
113- } ) ;
109+ out. zip_mut_with_same_shape( rhs, clone_iopf( A :: $mth) ) ;
114110 out
115111 } else {
112+ let shape = self . dim. broadcast_shape( & rhs. dim) . unwrap( ) ;
116113 let lhs = self . broadcast( shape. clone( ) ) . unwrap( ) ;
117- let rhs = rhs. broadcast( shape. clone( ) ) . unwrap( ) ;
118- // SAFETY: Overwrite all the elements in the array after
119- // it is created via `raw_view_mut`.
120- unsafe {
121- let mut out =ArrayBase :: <<S as MaybeUninitSubst <A >>:: Output , <D as BroadcastShape <E >>:: Output >:: maybe_uninit( shape. into_pattern( ) ) ;
122- let output_view = out. raw_view_mut( ) . cast:: <A >( ) ;
123- Zip :: from( & lhs) . and( & rhs)
124- . and( output_view)
125- . collect_with_partial( |x, y| {
126- x. clone( ) $operator y. clone( )
127- } )
128- . release_ownership( ) ;
129- out. assume_init( )
130- }
114+ let rhs = rhs. broadcast( shape) . unwrap( ) ;
115+ Zip :: from( & lhs) . and( & rhs) . map_collect_owned( clone_opf( A :: $mth) )
131116 }
132117 }
133118}
@@ -148,38 +133,24 @@ where
148133 A : Clone + $trt<B , Output =B >,
149134 B : Clone ,
150135 S : Data <Elem =A >,
151- S2 : DataOwned <Elem =B > + DataMut + MaybeUninitSubst < B > ,
152- < S2 as MaybeUninitSubst < B >> :: Output : DataMut ,
136+ S2 : DataOwned <Elem =B > + DataMut ,
137+ S2 :: MaybeUninit : DataMut ,
153138 D : Dimension ,
154139 E : Dimension + BroadcastShape <D >,
155140{
156141 type Output = ArrayBase <S2 , <E as BroadcastShape <D >>:: Output >;
157142 fn $mth( self , rhs: ArrayBase <S2 , E >) -> Self :: Output
158143 where
159144 {
160- let shape = rhs. dim. broadcast_shape( & self . dim) . unwrap( ) ;
161- if shape. slice( ) == rhs. dim. slice( ) {
145+ if self . ndim( ) == rhs. ndim( ) && self . shape( ) == rhs. shape( ) {
162146 let mut out = rhs. into_dimensionality:: <<E as BroadcastShape <D >>:: Output >( ) . unwrap( ) ;
163- out. zip_mut_with( self , |x, y| {
164- * x = y. clone( ) $operator x. clone( ) ;
165- } ) ;
147+ out. zip_mut_with_same_shape( self , clone_iopf_rev( A :: $mth) ) ;
166148 out
167149 } else {
150+ let shape = rhs. dim. broadcast_shape( & self . dim) . unwrap( ) ;
168151 let lhs = self . broadcast( shape. clone( ) ) . unwrap( ) ;
169- let rhs = rhs. broadcast( shape. clone( ) ) . unwrap( ) ;
170- // SAFETY: Overwrite all the elements in the array after
171- // it is created via `raw_view_mut`.
172- unsafe {
173- let mut out =ArrayBase :: <<S2 as MaybeUninitSubst <B >>:: Output , <E as BroadcastShape <D >>:: Output >:: maybe_uninit( shape. into_pattern( ) ) ;
174- let output_view = out. raw_view_mut( ) . cast:: <B >( ) ;
175- Zip :: from( & lhs) . and( & rhs)
176- . and( output_view)
177- . collect_with_partial( |x, y| {
178- x. clone( ) $operator y. clone( )
179- } )
180- . release_ownership( ) ;
181- out. assume_init( )
182- }
152+ let rhs = rhs. broadcast( shape) . unwrap( ) ;
153+ Zip :: from( & lhs) . and( & rhs) . map_collect_owned( clone_opf( A :: $mth) )
183154 }
184155 }
185156}
@@ -207,8 +178,7 @@ where
207178 let shape = self . dim. broadcast_shape( & rhs. dim) . unwrap( ) ;
208179 let lhs = self . broadcast( shape. clone( ) ) . unwrap( ) ;
209180 let rhs = rhs. broadcast( shape) . unwrap( ) ;
210- let out = Zip :: from( & lhs) . and( & rhs) . map_collect( |x, y| x. clone( ) $operator y. clone( ) ) ;
211- out
181+ Zip :: from( & lhs) . and( & rhs) . map_collect( clone_opf( A :: $mth) )
212182 }
213183}
214184
@@ -313,6 +283,18 @@ mod arithmetic_ops {
313283 use num_complex:: Complex ;
314284 use std:: ops:: * ;
315285
286+ fn clone_opf < A : Clone , B : Clone , C > ( f : impl Fn ( A , B ) -> C ) -> impl FnMut ( & A , & B ) -> C {
287+ move |x, y| f ( x. clone ( ) , y. clone ( ) )
288+ }
289+
290+ fn clone_iopf < A : Clone , B : Clone > ( f : impl Fn ( A , B ) -> A ) -> impl FnMut ( & mut A , & B ) {
291+ move |x, y| * x = f ( x. clone ( ) , y. clone ( ) )
292+ }
293+
294+ fn clone_iopf_rev < A : Clone , B : Clone > ( f : impl Fn ( A , B ) -> B ) -> impl FnMut ( & mut B , & A ) {
295+ move |x, y| * x = f ( y. clone ( ) , x. clone ( ) )
296+ }
297+
316298 impl_binary_op ! ( Add , +, add, +=, "addition" ) ;
317299 impl_binary_op ! ( Sub , -, sub, -=, "subtraction" ) ;
318300 impl_binary_op ! ( Mul , * , mul, *=, "multiplication" ) ;
0 commit comments