Skip to content

Commit 38840e8

Browse files
committed
iterator in NDArrayPhp
1 parent 01bfc77 commit 38840e8

File tree

4 files changed

+752
-11
lines changed

4 files changed

+752
-11
lines changed

src/LinearAlgebra.php

+219
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,28 @@ public function __construct($blas,$lapack,$math,$defaultFloatType=null)
2727
$this->defaultFloatType = $defaultFloatType;
2828
}
2929

30+
protected function printableShapes($values)
31+
{
32+
if(!is_array($values)) {
33+
if($values instanceof NDArray)
34+
return '('.implode(',',$values->shape()).')';
35+
if(is_object($values))
36+
return '"'.get_class($values).'"';
37+
if(is_numeric($values) || is_string($values))
38+
return strval($values);
39+
return gettype($values);
40+
}
41+
$string = '[';
42+
foreach($values as $value) {
43+
if($string!='[') {
44+
$string .= ',';
45+
}
46+
$string .= $this->printableShapes($value);
47+
}
48+
$string .= ']';
49+
return $string;
50+
}
51+
3052
public function alloc(array $shape,$dtype=null)
3153
{
3254
if($dtype===null)
@@ -383,6 +405,118 @@ public function gemm(
383405
return $C;
384406
}
385407

408+
/**
409+
*
410+
*/
411+
public function matmul(
412+
NDArray $A,
413+
NDArray $B,
414+
bool $transA=null,
415+
bool $transB=null,
416+
NDArray $C=null,
417+
float $alpha=null,
418+
float $beta=null
419+
) : NDArray
420+
{
421+
if($A->ndim()<2 || $B->ndim()<2) {
422+
throw new InvalidArgumentException('Dimensions rank must be greater then 2D or equal:['.
423+
implode(',',$A->shape()).']<=>['.implode(',',$B->shape()).']');
424+
}
425+
$shapeA = $A->shape();
426+
$shapeB = $B->shape();
427+
$shapeEA = [array_pop($shapeA)];
428+
array_unshift($shapeEA,array_pop($shapeA));
429+
$shapeEB = [array_pop($shapeB)];
430+
array_unshift($shapeEB,array_pop($shapeB));
431+
$batchA = (int)array_product($shapeA);
432+
$batchB = (int)array_product($shapeB);
433+
$flatA = $A->reshape(array_merge([$batchA],$shapeEA));
434+
$flatB = $B->reshape(array_merge([$batchB],$shapeEB));
435+
436+
if($transA) {
437+
$shapeEA = array_reverse($shapeEA);
438+
}
439+
if($transB) {
440+
$shapeEB = array_reverse($shapeEB);
441+
}
442+
if($shapeEA[1]!=$shapeEB[0]) {
443+
throw new InvalidArgumentException('The number of columns in "A" and the number of rows in "B" must be the same:['.
444+
implode(',',$A->shape()).']<=>['.implode(',',$B->shape()).']');
445+
}
446+
447+
$AA = $A->buffer();
448+
$BB = $B->buffer();
449+
$M = $shapeEA[0];
450+
$N = $shapeEB[1];
451+
$K = $shapeEA[1];
452+
453+
if($alpha===null) {
454+
$alpha = 1.0;
455+
}
456+
if($beta===null) {
457+
$beta = 0.0;
458+
}
459+
$lda = ($transA) ? $M : $K;
460+
$ldb = ($transB) ? $K : $N;
461+
$ldc = $N;
462+
$transA = ($transA) ? BLAS::Trans : BLAS::NoTrans;
463+
$transB = ($transB) ? BLAS::Trans : BLAS::NoTrans;
464+
465+
$shapeEC = [$shapeEA[0],$shapeEB[1]];
466+
if($batchA>$batchB) {
467+
$broadcastDest = $batchA;
468+
$broadcastBase = $batchB;
469+
$orgShapeC=array_merge($shapeA,$shapeEC);
470+
} else {
471+
$broadcastDest = $batchB;
472+
$broadcastBase = $batchA;
473+
$orgShapeC=array_merge($shapeB,$shapeEC);
474+
}
475+
if($broadcastDest % $broadcastBase != 0) {
476+
throw new InvalidArgumentException('Matrix size-incompatible for broadcast:['.
477+
implode(',',$A->shape()).']<=>['.implode(',',$B->shape()).']');
478+
}
479+
if($C!=null) {
480+
if($C->shape()!=$orgShapeC) {
481+
throw new InvalidArgumentException('"A" and "C" must have the same number of rows."B" and "C" must have the same number of columns:['.
482+
implode(',',$A->shape()).'] , ['.implode(',',$B->shape()).'] => ['.implode(',',$C->shape()).']');
483+
}
484+
} else {
485+
$C = $this->alloc($orgShapeC,$A->dtype());
486+
$this->zeros($C);
487+
}
488+
$flatC = $C->reshape(array_merge([$broadcastDest],$shapeEC));
489+
$CC = $C->buffer();
490+
$repeats = (int)floor($broadcastDest/$broadcastBase);
491+
$offA = $A->offset();
492+
$offB = $B->offset();
493+
$offC = $C->offset();
494+
$incA = $M*$K;
495+
$incB = $N*$K;
496+
$incC = $M*$N;
497+
for($i=0;$i<$repeats;$i++) {
498+
if($batchA>$batchB) {
499+
$offB = $B->offset();
500+
} else {
501+
$offA = $A->offset();
502+
}
503+
for($j=0;$j<$broadcastBase;$j++) {
504+
$this->blas->gemm(
505+
BLAS::RowMajor,$transA,$transB,
506+
$M,$N,$K,
507+
$alpha,
508+
$AA,$offA,$lda,
509+
$BB,$offB,$ldb,
510+
$beta,
511+
$CC,$offC,$ldc);
512+
$offA+=$incA;
513+
$offB+=$incB;
514+
$offC+=$incC;
515+
}
516+
}
517+
return $C;
518+
}
519+
386520
/**
387521
* ret := x_1 + ... + x_n
388522
*/
@@ -2057,6 +2191,91 @@ public function stack(
20572191
return $output;
20582192
}
20592193

2194+
public function concat(
2195+
array $values,
2196+
int $axis=null
2197+
) : NDArray
2198+
{
2199+
if($axis===null) {
2200+
$axis = -1;
2201+
}
2202+
if($axis<0) {
2203+
$axis = $values[0]->ndim() + $axis;
2204+
}
2205+
$m = null;
2206+
$base = null;
2207+
$n = 0;
2208+
$reshapeValues = [];
2209+
foreach ($values as $value) {
2210+
$shapePrefix = [];
2211+
$shape = $value->shape();
2212+
$mm = 1;
2213+
for($j=0;$j<$axis;$j++) {
2214+
$mmm = array_shift($shape);
2215+
$shapePrefix[] = $mmm;
2216+
$mm *= $mmm;
2217+
}
2218+
$nn = array_shift($shape);
2219+
if($base===null) {
2220+
$m = $mm;
2221+
$base = $shape;
2222+
} else {
2223+
if($m!=$mm||$base!=$shape) {
2224+
throw new InvalidArgumentException('Unmatch shape: '.
2225+
$this->printableShapes($values));
2226+
}
2227+
}
2228+
$n += $nn;
2229+
$reshapeValues[] = $value->reshape(array_merge([$mm,$nn],$shape));
2230+
}
2231+
$dims = $shape;
2232+
$shape = array_merge([$m,$n],$shape);
2233+
$output = $this->alloc($shape,$values[0]->dtype());
2234+
$i = 0;
2235+
foreach ($reshapeValues as $value) {
2236+
$nn = $value->shape()[1];
2237+
$this->doSlice(true,
2238+
$output,
2239+
[0,$i],[-1,$nn],
2240+
$value
2241+
);
2242+
$i += $nn;
2243+
}
2244+
$output = $output->reshape(array_merge($shapePrefix,[$n],$dims));
2245+
return $output;
2246+
}
2247+
2248+
public function split(
2249+
NDArray $input, array $sizeSplits, $axis=null
2250+
) : array
2251+
{
2252+
if($axis===null) {
2253+
$axis = -1;
2254+
}
2255+
if($axis<0) {
2256+
$axis = $input->ndim() + $axis;
2257+
}
2258+
$shapePrefix = [];
2259+
$shape = $input->shape();
2260+
$m = 1;
2261+
for($j=0;$j<$axis;$j++) {
2262+
$mmm = array_shift($shape);
2263+
$shapePrefix[] = $mmm;
2264+
$m *= $mmm;
2265+
}
2266+
$n = array_shift($shape);
2267+
$input = $input->reshape(array_merge([$m,$n],$shape));
2268+
$i = 0;
2269+
foreach ($sizeSplits as $size) {
2270+
$outputs[] = $this->doSlice(false,
2271+
$input,
2272+
[0,$i],[-1,$size]
2273+
)->reshape(array_merge($shapePrefix,[$size],$shape));
2274+
$i += $size;
2275+
}
2276+
return $outputs;
2277+
}
2278+
20602279
protected function doSlice(
20612280
bool $reverse,
20622281
NDArray $input,

src/NDArrayPhp.php

+20-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
use SplFixedArray;
55
use ArrayObject;
66
use ArrayAccess;
7+
use Countable;
8+
use IteratorAggregate;
79
use InvalidArgumentException;
810
use OutOfRangeException;
911
use LogicException;
@@ -13,7 +15,7 @@
1315
use Interop\Polite\Math\Matrix\BLAS;
1416
use Interop\Polite\Math\Matrix\NDArray;
1517

16-
class NDArrayPhp implements NDArray,Serializable
18+
class NDArrayPhp implements NDArray,Serializable,Countable,IteratorAggregate
1719
{
1820
protected $_shape;
1921
protected $_buffer;
@@ -199,8 +201,10 @@ public function size() : int
199201
public function reshape(array $shape) : NDArray
200202
{
201203
$this->assertShape($shape);
202-
if($this->size()!=array_product($shape))
203-
throw new InvalidArgumentException("Unmatch size");
204+
if($this->size()!=array_product($shape)) {
205+
throw new InvalidArgumentException("Unmatch size to reshape: ".
206+
"[".implode(',',$this->shape())."]=>[".implode(',',$shape)."]");
207+
}
204208
$newArray = new self($this->buffer(),$this->dtype(),$shape,$this->offset());
205209
return $newArray;
206210
}
@@ -305,6 +309,19 @@ public function offsetUnset( $offset )
305309
throw new LogicException("Unsuppored Operation");
306310
}
307311

312+
public function count()
313+
{
314+
return $this->_shape[0];
315+
}
316+
317+
public function getIterator()
318+
{
319+
$count = $this->_shape[0];
320+
for($i=0;$i<$count;$i++) {
321+
yield $i => $this->offsetGet($i);
322+
}
323+
}
324+
308325
public function setPortableSerializeMode(bool $mode)
309326
{
310327
$this->_portableSerializeMode = $mode ? true : false;

0 commit comments

Comments
 (0)