@@ -2594,6 +2594,35 @@ public function stack(
2594
2594
);
2595
2595
$ i ++;
2596
2596
}
2597
+ } elseif ($ axis ==2 ){
2598
+ $ k = count ($ values );
2599
+ $ shape = $ values [0 ]->shape ();
2600
+ $ m = array_shift ($ shape );
2601
+ $ n = array_shift ($ shape );
2602
+ array_unshift ($ shape ,$ k );
2603
+ array_unshift ($ shape ,$ n );
2604
+ array_unshift ($ shape ,$ m );
2605
+ $ output = $ this ->alloc ($ shape ,$ values [0 ]->dtype ());
2606
+ $ i = 0 ;
2607
+ foreach ($ values as $ value ){
2608
+ if (!($ value instanceof NDArray)) {
2609
+ throw new InvalidArgumentException ('values must be array of NDArray ' );
2610
+ }
2611
+ $ shape = $ value ->shape ();
2612
+ $ m = array_shift ($ shape );
2613
+ $ n = array_shift ($ shape );
2614
+ array_unshift ($ shape ,1 );
2615
+ array_unshift ($ shape ,$ n );
2616
+ array_unshift ($ shape ,$ m );
2617
+ $ value = $ value ->reshape (
2618
+ $ shape );
2619
+ $ this ->doSlice (true ,
2620
+ $ output ,
2621
+ [0 ,0 ,$ i ],[-1 ,-1 ,1 ],
2622
+ $ value
2623
+ );
2624
+ $ i ++;
2625
+ }
2597
2626
} else {
2598
2627
throw new InvalidArgumentException ('unsuppoted axis ' );
2599
2628
}
@@ -2618,12 +2647,10 @@ public function concat(
2618
2647
foreach ($ values as $ value ) {
2619
2648
$ shapePrefix = [];
2620
2649
$ shape = $ value ->shape ();
2621
- $ mm = 1 ;
2622
2650
for ($ j =0 ;$ j <$ axis ;$ j ++) {
2623
- $ mmm = array_shift ($ shape );
2624
- $ shapePrefix [] = $ mmm ;
2625
- $ mm *= $ mmm ;
2651
+ $ shapePrefix [] = array_shift ($ shape );
2626
2652
}
2653
+ $ mm = (int )array_product ($ shapePrefix );
2627
2654
$ nn = array_shift ($ shape );
2628
2655
if ($ base ===null ) {
2629
2656
$ m = $ mm ;
@@ -2666,12 +2693,10 @@ public function split(
2666
2693
}
2667
2694
$ shapePrefix = [];
2668
2695
$ shape = $ input ->shape ();
2669
- $ m = 1 ;
2670
2696
for ($ j =0 ;$ j <$ axis ;$ j ++) {
2671
- $ mmm = array_shift ($ shape );
2672
- $ shapePrefix [] = $ mmm ;
2673
- $ m *= $ mmm ;
2697
+ $ shapePrefix [] = array_shift ($ shape );
2674
2698
}
2699
+ $ m = (int )array_product ($ shapePrefix );
2675
2700
$ n = array_shift ($ shape );
2676
2701
$ input = $ input ->reshape (array_merge ([$ m ,$ n ],$ shape ));
2677
2702
$ i = 0 ;
@@ -2701,12 +2726,12 @@ protected function doSlice(
2701
2726
$ orgBegin = $ begin ;
2702
2727
$ orgSize = $ size ;
2703
2728
$ ndimBegin = count ($ begin );
2704
- if ($ ndimBegin <1 ||$ ndimBegin >2 ) {
2705
- throw new InvalidArgumentException ('begin must has 1 or 2 integer. ' );
2729
+ if ($ ndimBegin <1 ||$ ndimBegin >3 ) {
2730
+ throw new InvalidArgumentException ('begin must has 1 or 2 or 3 integer. ' );
2706
2731
}
2707
2732
$ ndimSize = count ($ size );
2708
- if ($ ndimSize <1 ||$ ndimSize >2 ) {
2709
- throw new InvalidArgumentException ('Size must has 1 or 2 integer. ' );
2733
+ if ($ ndimSize <1 ||$ ndimSize >3 ) {
2734
+ throw new InvalidArgumentException ('Size must has 1 or 2 or 3 integer. ' );
2710
2735
}
2711
2736
if ($ ndimBegin !=$ ndimSize ){
2712
2737
throw new InvalidArgumentException ('Unmatch shape of begin and size ' );
@@ -2716,6 +2741,8 @@ protected function doSlice(
2716
2741
throw new InvalidArgumentException ($ messageInput .' shape rank is low to slice ' );
2717
2742
}
2718
2743
$ shape = $ input ->shape ();
2744
+
2745
+ // ndim = 0
2719
2746
$ m = array_shift ($ shape );
2720
2747
$ startAxis0 = array_shift ($ begin );
2721
2748
if ($ startAxis0 <0 ){
@@ -2731,7 +2758,9 @@ protected function doSlice(
2731
2758
if ($ sizeAxis0 <1 ||$ startAxis0 +$ sizeAxis0 >$ m ){
2732
2759
throw new InvalidArgumentException ('size of axis 0 is invalid value. ' );
2733
2760
}
2734
- if ($ ndimBegin ==1 ){
2761
+
2762
+ // ndim = 1
2763
+ if ($ ndimBegin <=1 ){
2735
2764
$ n = 1 ;
2736
2765
$ startAxis1 = 0 ;
2737
2766
$ sizeAxis1 = 1 ;
@@ -2752,19 +2781,48 @@ protected function doSlice(
2752
2781
throw new InvalidArgumentException ('size of axis 1 is invalid value. ' );
2753
2782
}
2754
2783
}
2755
- $ k = array_product ($ shape );
2784
+
2785
+ // ndim = 2
2786
+ if ($ ndimBegin <=2 ){
2787
+ $ k = 1 ;
2788
+ $ startAxis2 = 0 ;
2789
+ $ sizeAxis2 = 1 ;
2790
+ } else {
2791
+ $ k = array_shift ($ shape );
2792
+ $ startAxis2 = array_shift ($ begin );
2793
+ if ($ startAxis2 <0 ){
2794
+ $ startAxis2 = $ k +$ startAxis2 ;
2795
+ }
2796
+ if ($ startAxis2 <0 ||$ startAxis2 >=$ k ){
2797
+ throw new InvalidArgumentException ('start of axis 2 is invalid value.:begin=[ ' .implode (', ' ,$ orgBegin ).'] ' );
2798
+ }
2799
+ $ sizeAxis2 = array_shift ($ size );
2800
+ if ($ sizeAxis2 <0 ){
2801
+ $ sizeAxis2 = $ k -$ startAxis2 +$ sizeAxis2 +1 ;
2802
+ }
2803
+ if ($ sizeAxis2 <1 ||$ startAxis2 +$ sizeAxis2 >$ k ){
2804
+ throw new InvalidArgumentException ('size of axis 2 is invalid value. ' );
2805
+ }
2806
+ }
2807
+ $ itemSize = array_product ($ shape );
2756
2808
$ outputShape = [$ sizeAxis0 ];
2757
- if ($ ndimBegin= =2 ){
2809
+ if ($ ndimBegin> =2 ){
2758
2810
array_push ($ outputShape ,
2759
2811
$ sizeAxis1 );
2760
2812
}
2813
+ if ($ ndimBegin >=3 ){
2814
+ array_push ($ outputShape ,
2815
+ $ sizeAxis2 );
2816
+ }
2761
2817
$ outputShape = array_merge (
2762
2818
$ outputShape ,$ shape );
2763
2819
if ($ output ==null ){
2764
2820
$ output = $ this ->alloc ($ outputShape ,$ input ->dtype ());
2765
2821
}else {
2766
2822
if ($ outputShape !=$ output ->shape ()){
2767
- throw new InvalidArgumentException ('Unmatch output shape ' );
2823
+ throw new InvalidArgumentException ('Unmatch output shape: ' .
2824
+ $ this ->printableShapes ($ outputShape ).'<=> ' .
2825
+ $ this ->printableShapes ($ output ->shape ()));
2768
2826
}
2769
2827
}
2770
2828
@@ -2780,10 +2838,12 @@ protected function doSlice(
2780
2838
$ m ,
2781
2839
$ n ,
2782
2840
$ k ,
2841
+ $ itemSize ,
2783
2842
$ A ,$ offsetA ,$ incA ,
2784
2843
$ Y ,$ offsetY ,$ incY ,
2785
2844
$ startAxis0 ,$ sizeAxis0 ,
2786
- $ startAxis1 ,$ sizeAxis1
2845
+ $ startAxis1 ,$ sizeAxis1 ,
2846
+ $ startAxis2 ,$ sizeAxis2
2787
2847
);
2788
2848
return $ output ;
2789
2849
}
@@ -2805,13 +2865,16 @@ public function repeat(NDArray $A, int $repeats)
2805
2865
$ B = $ this ->alloc ($ shape ,$ A ->dtype ());
2806
2866
$ m = $ s1 ;
2807
2867
$ n = $ repeats ;
2808
- $ k = (int )array_product ($ shapeCell );
2868
+ $ k = 1 ;
2869
+ $ size = (int )array_product ($ shapeCell );
2809
2870
$ AA = $ A ->buffer ();
2810
2871
$ offA = $ A ->offset ();
2811
2872
$ BB = $ B ->buffer ();
2812
2873
$ offB = $ B ->offset ();
2813
2874
$ startAxis0 = 0 ;
2814
2875
$ sizeAxis0 = $ m ;
2876
+ $ startAxis2 = 0 ;
2877
+ $ sizeAxis2 = 1 ;
2815
2878
for ($ i =0 ;$ i <$ repeats ;$ i ++) {
2816
2879
$ startAxis1 = $ i ;
2817
2880
$ sizeAxis1 = 1 ;
@@ -2821,10 +2884,12 @@ public function repeat(NDArray $A, int $repeats)
2821
2884
$ m ,
2822
2885
$ n ,
2823
2886
$ k ,
2887
+ $ size ,
2824
2888
$ BB ,$ offB ,1 ,
2825
2889
$ AA ,$ offA ,1 ,
2826
2890
$ startAxis0 ,$ sizeAxis0 ,
2827
- $ startAxis1 ,$ sizeAxis1
2891
+ $ startAxis1 ,$ sizeAxis1 ,
2892
+ $ startAxis2 ,$ sizeAxis2
2828
2893
);
2829
2894
}
2830
2895
return $ B ;
@@ -2985,20 +3050,29 @@ public function numericalGradient(
2985
3050
$ this ->zeros ($ grad );
2986
3051
$ grads [] = $ grad ;
2987
3052
$ size = $ x ->size ();
2988
- $ xx = $ x ->buffer ( );
2989
- $ idx = $ x ->offset ();
2990
- $ gg = $ grad ->buffer ( );
2991
- $ gidx = $ grad ->offset ();
3053
+ $ xx = $ x ->reshape ([ $ x -> size ()] );
3054
+ // $idx = $x->offset();
3055
+ $ gg = $ grad ->reshape ([ $ grad -> size ()] );
3056
+ // $gidx = $grad->offset();
2992
3057
$ h2 = $ h *2 ;
2993
- for ($ i =0 ;$ i <$ size ;$ i ++,$ idx ++,$ gidx ++) {
2994
- $ value = $ xx [$ idx ];
2995
- $ xx [$ idx ] = $ value + $ h ;
3058
+ for ($ i =0 ;$ i <$ size ;$ i ++) {
3059
+ // $value = $xx[$idx];
3060
+ $ value = $ this ->copy ($ xx [[$ i ,$ i ]]);
3061
+ // $xx[$idx] = $value + $h;
3062
+ $ this ->copy ($ this ->increment ($ this ->copy ($ value ),$ h ),$ xx [[$ i ,$ i ]]);
3063
+ //echo $value[0]."-h =>".$xx[$i]."\n";
2996
3064
$ y1 = $ f (...$ variables );
2997
- $ xx [$ idx ] = $ value - $ h ;
3065
+ // $xx[$idx] = $value - $h;
3066
+ $ this ->copy ($ this ->increment ($ this ->copy ($ value ),-$ h ),$ xx [[$ i ,$ i ]]);
3067
+ //echo $value[0]."-h =>".$xx[$i]."\n";
2998
3068
$ y2 = $ f (...$ variables );
2999
3069
$ d = $ this ->axpy ($ y2 ,$ this ->copy ($ y1 ),-1 );
3000
- $ gg [$ gidx ] = $ this ->sum ($ d )/$ h2 ;
3001
- $ xx [$ idx ] = $ value ;
3070
+ // $gg[$gidx] = $this->sum($d)/$h2;
3071
+ $ sum = $ this ->reduceSum ($ d ->reshape ([$ d ->size (),1 ]));
3072
+ //echo "d=".$sum[0]."\n";
3073
+ $ this ->copy ($ this ->scal (1 /$ h2 ,$ sum ),$ gg [[$ i ,$ i ]]);
3074
+ // $xx[$idx] = $value;
3075
+ $ this ->copy ($ value ,$ xx [[$ i ,$ i ]]);
3002
3076
}
3003
3077
}
3004
3078
return $ grads ;
0 commit comments