Skip to content

Commit f498a13

Browse files
committed
Math slice functions supports axis 2
1 parent 685686e commit f498a13

File tree

6 files changed

+632
-225
lines changed

6 files changed

+632
-225
lines changed

composer.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
"interop-phpobjects/polite-math": ">=1.0.4"
1010
},
1111
"suggest": {
12-
"ext-rindow_openblas": "0.2.2 strongly recommended for speed",
13-
"ext-rindow_opencl": "0.1.1 GPU/OpenCL support",
12+
"ext-rindow_openblas": "0.2.3 strongly recommended for speed",
13+
"ext-rindow_opencl": "0.1.3 GPU/OpenCL support",
1414
"ext-rindow_clblast": "0.1.2 BLAS on GPU/OpenCL support",
1515
"rindow/math-plot": "for OpenCL tunning"
1616
},

src/LinearAlgebra.php

+103-29
Original file line numberDiff line numberDiff line change
@@ -2594,6 +2594,35 @@ public function stack(
25942594
);
25952595
$i++;
25962596
}
2597+
} elseif($axis==2){
2598+
$k = count($values);
2599+
$shape = $values[0]->shape();
2600+
$m = array_shift($shape);
2601+
$n = array_shift($shape);
2602+
array_unshift($shape,$k);
2603+
array_unshift($shape,$n);
2604+
array_unshift($shape,$m);
2605+
$output = $this->alloc($shape,$values[0]->dtype());
2606+
$i = 0;
2607+
foreach($values as $value){
2608+
if(!($value instanceof NDArray)) {
2609+
throw new InvalidArgumentException('values must be array of NDArray');
2610+
}
2611+
$shape = $value->shape();
2612+
$m = array_shift($shape);
2613+
$n = array_shift($shape);
2614+
array_unshift($shape,1);
2615+
array_unshift($shape,$n);
2616+
array_unshift($shape,$m);
2617+
$value = $value->reshape(
2618+
$shape);
2619+
$this->doSlice(true,
2620+
$output,
2621+
[0,0,$i],[-1,-1,1],
2622+
$value
2623+
);
2624+
$i++;
2625+
}
25972626
} else {
25982627
throw new InvalidArgumentException('unsuppoted axis');
25992628
}
@@ -2618,12 +2647,10 @@ public function concat(
26182647
foreach ($values as $value) {
26192648
$shapePrefix = [];
26202649
$shape = $value->shape();
2621-
$mm = 1;
26222650
for($j=0;$j<$axis;$j++) {
2623-
$mmm = array_shift($shape);
2624-
$shapePrefix[] = $mmm;
2625-
$mm *= $mmm;
2651+
$shapePrefix[] = array_shift($shape);
26262652
}
2653+
$mm = (int)array_product($shapePrefix);
26272654
$nn = array_shift($shape);
26282655
if($base===null) {
26292656
$m = $mm;
@@ -2666,12 +2693,10 @@ public function split(
26662693
}
26672694
$shapePrefix = [];
26682695
$shape = $input->shape();
2669-
$m = 1;
26702696
for($j=0;$j<$axis;$j++) {
2671-
$mmm = array_shift($shape);
2672-
$shapePrefix[] = $mmm;
2673-
$m *= $mmm;
2697+
$shapePrefix[] = array_shift($shape);
26742698
}
2699+
$m = (int)array_product($shapePrefix);
26752700
$n = array_shift($shape);
26762701
$input = $input->reshape(array_merge([$m,$n],$shape));
26772702
$i = 0;
@@ -2701,12 +2726,12 @@ protected function doSlice(
27012726
$orgBegin = $begin;
27022727
$orgSize = $size;
27032728
$ndimBegin = count($begin);
2704-
if($ndimBegin<1||$ndimBegin>2) {
2705-
throw new InvalidArgumentException('begin must has 1 or 2 integer.');
2729+
if($ndimBegin<1||$ndimBegin>3) {
2730+
throw new InvalidArgumentException('begin must has 1 or 2 or 3 integer.');
27062731
}
27072732
$ndimSize = count($size);
2708-
if($ndimSize<1||$ndimSize>2) {
2709-
throw new InvalidArgumentException('Size must has 1 or 2 integer.');
2733+
if($ndimSize<1||$ndimSize>3) {
2734+
throw new InvalidArgumentException('Size must has 1 or 2 or 3 integer.');
27102735
}
27112736
if($ndimBegin!=$ndimSize){
27122737
throw new InvalidArgumentException('Unmatch shape of begin and size');
@@ -2716,6 +2741,8 @@ protected function doSlice(
27162741
throw new InvalidArgumentException($messageInput.' shape rank is low to slice');
27172742
}
27182743
$shape = $input->shape();
2744+
2745+
// ndim = 0
27192746
$m = array_shift($shape);
27202747
$startAxis0 = array_shift($begin);
27212748
if($startAxis0<0){
@@ -2731,7 +2758,9 @@ protected function doSlice(
27312758
if($sizeAxis0<1||$startAxis0+$sizeAxis0>$m){
27322759
throw new InvalidArgumentException('size of axis 0 is invalid value.');
27332760
}
2734-
if($ndimBegin==1){
2761+
2762+
// ndim = 1
2763+
if($ndimBegin<=1){
27352764
$n = 1;
27362765
$startAxis1 = 0;
27372766
$sizeAxis1 = 1;
@@ -2752,19 +2781,48 @@ protected function doSlice(
27522781
throw new InvalidArgumentException('size of axis 1 is invalid value.');
27532782
}
27542783
}
2755-
$k = array_product($shape);
2784+
2785+
// ndim = 2
2786+
if($ndimBegin<=2){
2787+
$k = 1;
2788+
$startAxis2 = 0;
2789+
$sizeAxis2 = 1;
2790+
} else {
2791+
$k = array_shift($shape);
2792+
$startAxis2 = array_shift($begin);
2793+
if($startAxis2<0){
2794+
$startAxis2 = $k+$startAxis2;
2795+
}
2796+
if($startAxis2<0||$startAxis2>=$k){
2797+
throw new InvalidArgumentException('start of axis 2 is invalid value.:begin=['.implode(',',$orgBegin).']');
2798+
}
2799+
$sizeAxis2 = array_shift($size);
2800+
if($sizeAxis2<0){
2801+
$sizeAxis2 = $k-$startAxis2+$sizeAxis2+1;
2802+
}
2803+
if($sizeAxis2<1||$startAxis2+$sizeAxis2>$k){
2804+
throw new InvalidArgumentException('size of axis 2 is invalid value.');
2805+
}
2806+
}
2807+
$itemSize = array_product($shape);
27562808
$outputShape = [$sizeAxis0];
2757-
if($ndimBegin==2){
2809+
if($ndimBegin>=2){
27582810
array_push($outputShape,
27592811
$sizeAxis1);
27602812
}
2813+
if($ndimBegin>=3){
2814+
array_push($outputShape,
2815+
$sizeAxis2);
2816+
}
27612817
$outputShape = array_merge(
27622818
$outputShape,$shape);
27632819
if($output==null){
27642820
$output = $this->alloc($outputShape,$input->dtype());
27652821
}else{
27662822
if($outputShape!=$output->shape()){
2767-
throw new InvalidArgumentException('Unmatch output shape');
2823+
throw new InvalidArgumentException('Unmatch output shape: '.
2824+
$this->printableShapes($outputShape).'<=>'.
2825+
$this->printableShapes($output->shape()));
27682826
}
27692827
}
27702828

@@ -2780,10 +2838,12 @@ protected function doSlice(
27802838
$m,
27812839
$n,
27822840
$k,
2841+
$itemSize,
27832842
$A,$offsetA,$incA,
27842843
$Y,$offsetY,$incY,
27852844
$startAxis0,$sizeAxis0,
2786-
$startAxis1,$sizeAxis1
2845+
$startAxis1,$sizeAxis1,
2846+
$startAxis2,$sizeAxis2
27872847
);
27882848
return $output;
27892849
}
@@ -2805,13 +2865,16 @@ public function repeat(NDArray $A, int $repeats)
28052865
$B = $this->alloc($shape,$A->dtype());
28062866
$m = $s1;
28072867
$n = $repeats;
2808-
$k = (int)array_product($shapeCell);
2868+
$k = 1;
2869+
$size = (int)array_product($shapeCell);
28092870
$AA = $A->buffer();
28102871
$offA = $A->offset();
28112872
$BB = $B->buffer();
28122873
$offB = $B->offset();
28132874
$startAxis0 = 0;
28142875
$sizeAxis0 = $m;
2876+
$startAxis2 = 0;
2877+
$sizeAxis2 = 1;
28152878
for($i=0;$i<$repeats;$i++) {
28162879
$startAxis1 = $i;
28172880
$sizeAxis1 = 1;
@@ -2821,10 +2884,12 @@ public function repeat(NDArray $A, int $repeats)
28212884
$m,
28222885
$n,
28232886
$k,
2887+
$size,
28242888
$BB,$offB,1,
28252889
$AA,$offA,1,
28262890
$startAxis0,$sizeAxis0,
2827-
$startAxis1,$sizeAxis1
2891+
$startAxis1,$sizeAxis1,
2892+
$startAxis2,$sizeAxis2
28282893
);
28292894
}
28302895
return $B;
@@ -2985,20 +3050,29 @@ public function numericalGradient(
29853050
$this->zeros($grad);
29863051
$grads[] = $grad;
29873052
$size = $x->size();
2988-
$xx = $x->buffer();
2989-
$idx = $x->offset();
2990-
$gg = $grad->buffer();
2991-
$gidx = $grad->offset();
3053+
$xx = $x->reshape([$x->size()]);
3054+
//$idx = $x->offset();
3055+
$gg = $grad->reshape([$grad->size()]);
3056+
//$gidx = $grad->offset();
29923057
$h2 = $h*2 ;
2993-
for($i=0;$i<$size;$i++,$idx++,$gidx++) {
2994-
$value = $xx[$idx];
2995-
$xx[$idx] = $value + $h;
3058+
for($i=0;$i<$size;$i++) {
3059+
// $value = $xx[$idx];
3060+
$value = $this->copy($xx[[$i,$i]]);
3061+
// $xx[$idx] = $value + $h;
3062+
$this->copy($this->increment($this->copy($value),$h),$xx[[$i,$i]]);
3063+
//echo $value[0]."-h =>".$xx[$i]."\n";
29963064
$y1 = $f(...$variables);
2997-
$xx[$idx] = $value - $h;
3065+
// $xx[$idx] = $value - $h;
3066+
$this->copy($this->increment($this->copy($value),-$h),$xx[[$i,$i]]);
3067+
//echo $value[0]."-h =>".$xx[$i]."\n";
29983068
$y2 = $f(...$variables);
29993069
$d = $this->axpy($y2,$this->copy($y1),-1);
3000-
$gg[$gidx] = $this->sum($d)/$h2;
3001-
$xx[$idx] = $value;
3070+
// $gg[$gidx] = $this->sum($d)/$h2;
3071+
$sum = $this->reduceSum($d->reshape([$d->size(),1]));
3072+
//echo "d=".$sum[0]."\n";
3073+
$this->copy($this->scal(1/$h2,$sum),$gg[[$i,$i]]);
3074+
// $xx[$idx] = $value;
3075+
$this->copy($value,$xx[[$i,$i]]);
30023076
}
30033077
}
30043078
return $grads;

0 commit comments

Comments
 (0)