@@ -6,11 +6,11 @@ use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
66///
77/// Uses the [NumPy broadcasting rules]
88// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
9- fn co_broadcast < D1 , D2 , Output > ( shape1 : & D1 , shape2 : & D2 ) -> Result < Output , ShapeError >
10- where
11- D1 : Dimension ,
12- D2 : Dimension ,
13- Output : Dimension ,
9+ pub ( crate ) fn co_broadcast < D1 , D2 , Output > ( shape1 : & D1 , shape2 : & D2 ) -> Result < Output , ShapeError >
10+ where
11+ D1 : Dimension ,
12+ D2 : Dimension ,
13+ Output : Dimension ,
1414{
1515 let ( k, overflow) = shape1. ndim ( ) . overflowing_sub ( shape2. ndim ( ) ) ;
1616 // Swap the order if d2 is longer.
@@ -37,40 +37,23 @@ fn co_broadcast<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, Shap
3737pub trait DimMax < Other : Dimension > {
3838 /// The resulting dimension type after broadcasting.
3939 type Output : Dimension ;
40-
41- /// Determines the shape after broadcasting the shapes together.
42- ///
43- /// If the shapes are not compatible, returns `Err`.
44- fn broadcast_shape ( & self , other : & Other ) -> Result < Self :: Output , ShapeError > ;
4540}
4641
4742/// Dimensions of the same type remain unchanged when co_broadcast.
4843/// So you can directly use D as the resulting type.
4944/// (Instead of <D as DimMax<D>>::BroadcastOutput)
5045impl < D : Dimension > DimMax < D > for D {
5146 type Output = D ;
52-
53- fn broadcast_shape ( & self , other : & D ) -> Result < Self :: Output , ShapeError > {
54- co_broadcast :: < D , D , Self :: Output > ( self , other)
55- }
5647}
5748
5849macro_rules! impl_broadcast_distinct_fixed {
5950 ( $smaller: ty, $larger: ty) => {
6051 impl DimMax <$larger> for $smaller {
6152 type Output = $larger;
62-
63- fn broadcast_shape( & self , other: & $larger) -> Result <Self :: Output , ShapeError > {
64- co_broadcast:: <Self , $larger, Self :: Output >( self , other)
65- }
6653 }
6754
6855 impl DimMax <$smaller> for $larger {
6956 type Output = $larger;
70-
71- fn broadcast_shape( & self , other: & $smaller) -> Result <Self :: Output , ShapeError > {
72- co_broadcast:: <Self , $smaller, Self :: Output >( self , other)
73- }
7457 }
7558 } ;
7659}
@@ -103,3 +86,57 @@ impl_broadcast_distinct_fixed!(Ix3, IxDyn);
10386impl_broadcast_distinct_fixed ! ( Ix4 , IxDyn ) ;
10487impl_broadcast_distinct_fixed ! ( Ix5 , IxDyn ) ;
10588impl_broadcast_distinct_fixed ! ( Ix6 , IxDyn ) ;
89+
90+
91+ #[ cfg( test) ]
92+ mod tests {
93+ use super :: co_broadcast;
94+ use crate :: { Dimension , Dim , DimMax , ShapeError , Ix0 , IxDynImpl , ErrorKind } ;
95+
96+ #[ test]
97+ fn test_broadcast_shape ( ) {
98+ fn test_co < D1 , D2 > (
99+ d1 : & D1 ,
100+ d2 : & D2 ,
101+ r : Result < <D1 as DimMax < D2 > >:: Output , ShapeError > ,
102+ ) where
103+ D1 : Dimension + DimMax < D2 > ,
104+ D2 : Dimension ,
105+ {
106+ let d = co_broadcast :: < D1 , D2 , <D1 as DimMax < D2 > >:: Output > ( & d1, d2) ;
107+ assert_eq ! ( d, r) ;
108+ }
109+ test_co ( & Dim ( [ 2 , 3 ] ) , & Dim ( [ 4 , 1 , 3 ] ) , Ok ( Dim ( [ 4 , 2 , 3 ] ) ) ) ;
110+ test_co (
111+ & Dim ( [ 1 , 2 , 2 ] ) ,
112+ & Dim ( [ 1 , 3 , 4 ] ) ,
113+ Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ,
114+ ) ;
115+ test_co ( & Dim ( [ 3 , 4 , 5 ] ) , & Ix0 ( ) , Ok ( Dim ( [ 3 , 4 , 5 ] ) ) ) ;
116+ let v = vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 ] ;
117+ test_co (
118+ & Dim ( vec ! [ 1 , 1 , 3 , 1 , 5 , 1 , 7 ] ) ,
119+ & Dim ( [ 2 , 1 , 4 , 1 , 6 , 1 ] ) ,
120+ Ok ( Dim ( IxDynImpl :: from ( v. as_slice ( ) ) ) ) ,
121+ ) ;
122+ let d = Dim ( [ 1 , 2 , 1 , 3 ] ) ;
123+ test_co ( & d, & d, Ok ( d) ) ;
124+ test_co (
125+ & Dim ( [ 2 , 1 , 2 ] ) . into_dyn ( ) ,
126+ & Dim ( 0 ) ,
127+ Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ,
128+ ) ;
129+ test_co (
130+ & Dim ( [ 2 , 1 , 1 ] ) ,
131+ & Dim ( [ 0 , 0 , 1 , 3 , 4 ] ) ,
132+ Ok ( Dim ( [ 0 , 0 , 2 , 3 , 4 ] ) ) ,
133+ ) ;
134+ test_co ( & Dim ( [ 0 ] ) , & Dim ( [ 0 , 0 , 0 ] ) , Ok ( Dim ( [ 0 , 0 , 0 ] ) ) ) ;
135+ test_co ( & Dim ( 1 ) , & Dim ( [ 1 , 0 , 0 ] ) , Ok ( Dim ( [ 1 , 0 , 0 ] ) ) ) ;
136+ test_co (
137+ & Dim ( [ 1 , 3 , 0 , 1 , 1 ] ) ,
138+ & Dim ( [ 1 , 2 , 3 , 1 ] ) ,
139+ Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ,
140+ ) ;
141+ }
142+ }
0 commit comments