@@ -6,11 +6,11 @@ use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
66///
77/// Uses the [NumPy broadcasting rules]
88// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
9- fn co_broadcast < D1 , D2 , Output > ( shape1 : & D1 , shape2 : & D2 ) -> Result < Output , ShapeError >
10- where
11- D1 : Dimension ,
12- D2 : Dimension ,
13- Output : Dimension ,
9+ pub ( crate ) fn co_broadcast < D1 , D2 , Output > ( shape1 : & D1 , shape2 : & D2 ) -> Result < Output , ShapeError >
10+ where
11+ D1 : Dimension ,
12+ D2 : Dimension ,
13+ Output : Dimension ,
1414{
1515 let ( k, overflow) = shape1. ndim ( ) . overflowing_sub ( shape2. ndim ( ) ) ;
1616 // Swap the order if d2 is longer.
@@ -37,40 +37,23 @@ fn co_broadcast<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, Shap
3737pub trait DimMax < Other : Dimension > {
3838 /// The resulting dimension type after broadcasting.
3939 type Output : Dimension ;
40-
41- /// Determines the shape after broadcasting the shapes together.
42- ///
43- /// If the shapes are not compatible, returns `Err`.
44- fn broadcast_shape ( & self , other : & Other ) -> Result < Self :: Output , ShapeError > ;
4540}
4641
4742/// Dimensions of the same type remain unchanged when co_broadcast.
4843/// So you can directly use D as the resulting type.
4944/// (Instead of <D as DimMax<D>>::BroadcastOutput)
5045impl < D : Dimension > DimMax < D > for D {
5146 type Output = D ;
52-
53- fn broadcast_shape ( & self , other : & D ) -> Result < Self :: Output , ShapeError > {
54- co_broadcast :: < D , D , Self :: Output > ( self , other)
55- }
5647}
5748
5849macro_rules! impl_broadcast_distinct_fixed {
5950 ( $smaller: ty, $larger: ty) => {
6051 impl DimMax <$larger> for $smaller {
6152 type Output = $larger;
62-
63- fn broadcast_shape( & self , other: & $larger) -> Result <Self :: Output , ShapeError > {
64- co_broadcast:: <Self , $larger, Self :: Output >( self , other)
65- }
6653 }
6754
6855 impl DimMax <$smaller> for $larger {
6956 type Output = $larger;
70-
71- fn broadcast_shape( & self , other: & $smaller) -> Result <Self :: Output , ShapeError > {
72- co_broadcast:: <Self , $smaller, Self :: Output >( self , other)
73- }
7457 }
7558 } ;
7659}
@@ -103,3 +86,58 @@ impl_broadcast_distinct_fixed!(Ix3, IxDyn);
10386impl_broadcast_distinct_fixed ! ( Ix4 , IxDyn ) ;
10487impl_broadcast_distinct_fixed ! ( Ix5 , IxDyn ) ;
10588impl_broadcast_distinct_fixed ! ( Ix6 , IxDyn ) ;
89+
90+
91+ #[ cfg( test) ]
92+ #[ cfg( feature = "std" ) ]
93+ mod tests {
94+ use super :: co_broadcast;
95+ use crate :: { Dimension , Dim , DimMax , ShapeError , Ix0 , IxDynImpl , ErrorKind } ;
96+
97+ #[ test]
98+ fn test_broadcast_shape ( ) {
99+ fn test_co < D1 , D2 > (
100+ d1 : & D1 ,
101+ d2 : & D2 ,
102+ r : Result < <D1 as DimMax < D2 > >:: Output , ShapeError > ,
103+ ) where
104+ D1 : Dimension + DimMax < D2 > ,
105+ D2 : Dimension ,
106+ {
107+ let d = co_broadcast :: < D1 , D2 , <D1 as DimMax < D2 > >:: Output > ( & d1, d2) ;
108+ assert_eq ! ( d, r) ;
109+ }
110+ test_co ( & Dim ( [ 2 , 3 ] ) , & Dim ( [ 4 , 1 , 3 ] ) , Ok ( Dim ( [ 4 , 2 , 3 ] ) ) ) ;
111+ test_co (
112+ & Dim ( [ 1 , 2 , 2 ] ) ,
113+ & Dim ( [ 1 , 3 , 4 ] ) ,
114+ Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ,
115+ ) ;
116+ test_co ( & Dim ( [ 3 , 4 , 5 ] ) , & Ix0 ( ) , Ok ( Dim ( [ 3 , 4 , 5 ] ) ) ) ;
117+ let v = vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 ] ;
118+ test_co (
119+ & Dim ( vec ! [ 1 , 1 , 3 , 1 , 5 , 1 , 7 ] ) ,
120+ & Dim ( [ 2 , 1 , 4 , 1 , 6 , 1 ] ) ,
121+ Ok ( Dim ( IxDynImpl :: from ( v. as_slice ( ) ) ) ) ,
122+ ) ;
123+ let d = Dim ( [ 1 , 2 , 1 , 3 ] ) ;
124+ test_co ( & d, & d, Ok ( d) ) ;
125+ test_co (
126+ & Dim ( [ 2 , 1 , 2 ] ) . into_dyn ( ) ,
127+ & Dim ( 0 ) ,
128+ Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ,
129+ ) ;
130+ test_co (
131+ & Dim ( [ 2 , 1 , 1 ] ) ,
132+ & Dim ( [ 0 , 0 , 1 , 3 , 4 ] ) ,
133+ Ok ( Dim ( [ 0 , 0 , 2 , 3 , 4 ] ) ) ,
134+ ) ;
135+ test_co ( & Dim ( [ 0 ] ) , & Dim ( [ 0 , 0 , 0 ] ) , Ok ( Dim ( [ 0 , 0 , 0 ] ) ) ) ;
136+ test_co ( & Dim ( 1 ) , & Dim ( [ 1 , 0 , 0 ] ) , Ok ( Dim ( [ 1 , 0 , 0 ] ) ) ) ;
137+ test_co (
138+ & Dim ( [ 1 , 3 , 0 , 1 , 1 ] ) ,
139+ & Dim ( [ 1 , 2 , 3 , 1 ] ) ,
140+ Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ,
141+ ) ;
142+ }
143+ }
0 commit comments