@@ -27,6 +27,28 @@ public function __construct($blas,$lapack,$math,$defaultFloatType=null)
27
27
$ this ->defaultFloatType = $ defaultFloatType ;
28
28
}
29
29
30
+ protected function printableShapes ($ values )
31
+ {
32
+ if (!is_array ($ values )) {
33
+ if ($ values instanceof NDArray)
34
+ return '( ' .implode (', ' ,$ values ->shape ()).') ' ;
35
+ if (is_object ($ values ))
36
+ return '" ' .get_class ($ values ).'" ' ;
37
+ if (is_numeric ($ values ) || is_string ($ values ))
38
+ return strval ($ values );
39
+ return gettype ($ values );
40
+ }
41
+ $ string = '[ ' ;
42
+ foreach ($ values as $ value ) {
43
+ if ($ string !='[ ' ) {
44
+ $ string .= ', ' ;
45
+ }
46
+ $ string .= $ this ->printableShapes ($ value );
47
+ }
48
+ $ string .= '] ' ;
49
+ return $ string ;
50
+ }
51
+
30
52
public function alloc (array $ shape ,$ dtype =null )
31
53
{
32
54
if ($ dtype ===null )
@@ -383,6 +405,118 @@ public function gemm(
383
405
return $ C ;
384
406
}
385
407
408
+ /**
409
+ *
410
+ */
411
+ public function matmul (
412
+ NDArray $ A ,
413
+ NDArray $ B ,
414
+ bool $ transA =null ,
415
+ bool $ transB =null ,
416
+ NDArray $ C =null ,
417
+ float $ alpha =null ,
418
+ float $ beta =null
419
+ ) : NDArray
420
+ {
421
+ if ($ A ->ndim ()<2 || $ B ->ndim ()<2 ) {
422
+ throw new InvalidArgumentException ('Dimensions rank must be greater then 2D or equal:[ ' .
423
+ implode (', ' ,$ A ->shape ()).']<=>[ ' .implode (', ' ,$ B ->shape ()).'] ' );
424
+ }
425
+ $ shapeA = $ A ->shape ();
426
+ $ shapeB = $ B ->shape ();
427
+ $ shapeEA = [array_pop ($ shapeA )];
428
+ array_unshift ($ shapeEA ,array_pop ($ shapeA ));
429
+ $ shapeEB = [array_pop ($ shapeB )];
430
+ array_unshift ($ shapeEB ,array_pop ($ shapeB ));
431
+ $ batchA = (int )array_product ($ shapeA );
432
+ $ batchB = (int )array_product ($ shapeB );
433
+ $ flatA = $ A ->reshape (array_merge ([$ batchA ],$ shapeEA ));
434
+ $ flatB = $ B ->reshape (array_merge ([$ batchB ],$ shapeEB ));
435
+
436
+ if ($ transA ) {
437
+ $ shapeEA = array_reverse ($ shapeEA );
438
+ }
439
+ if ($ transB ) {
440
+ $ shapeEB = array_reverse ($ shapeEB );
441
+ }
442
+ if ($ shapeEA [1 ]!=$ shapeEB [0 ]) {
443
+ throw new InvalidArgumentException ('The number of columns in "A" and the number of rows in "B" must be the same:[ ' .
444
+ implode (', ' ,$ A ->shape ()).']<=>[ ' .implode (', ' ,$ B ->shape ()).'] ' );
445
+ }
446
+
447
+ $ AA = $ A ->buffer ();
448
+ $ BB = $ B ->buffer ();
449
+ $ M = $ shapeEA [0 ];
450
+ $ N = $ shapeEB [1 ];
451
+ $ K = $ shapeEA [1 ];
452
+
453
+ if ($ alpha ===null ) {
454
+ $ alpha = 1.0 ;
455
+ }
456
+ if ($ beta ===null ) {
457
+ $ beta = 0.0 ;
458
+ }
459
+ $ lda = ($ transA ) ? $ M : $ K ;
460
+ $ ldb = ($ transB ) ? $ K : $ N ;
461
+ $ ldc = $ N ;
462
+ $ transA = ($ transA ) ? BLAS ::Trans : BLAS ::NoTrans;
463
+ $ transB = ($ transB ) ? BLAS ::Trans : BLAS ::NoTrans;
464
+
465
+ $ shapeEC = [$ shapeEA [0 ],$ shapeEB [1 ]];
466
+ if ($ batchA >$ batchB ) {
467
+ $ broadcastDest = $ batchA ;
468
+ $ broadcastBase = $ batchB ;
469
+ $ orgShapeC =array_merge ($ shapeA ,$ shapeEC );
470
+ } else {
471
+ $ broadcastDest = $ batchB ;
472
+ $ broadcastBase = $ batchA ;
473
+ $ orgShapeC =array_merge ($ shapeB ,$ shapeEC );
474
+ }
475
+ if ($ broadcastDest % $ broadcastBase != 0 ) {
476
+ throw new InvalidArgumentException ('Matrix size-incompatible for broadcast:[ ' .
477
+ implode (', ' ,$ A ->shape ()).']<=>[ ' .implode (', ' ,$ B ->shape ()).'] ' );
478
+ }
479
+ if ($ C !=null ) {
480
+ if ($ C ->shape ()!=$ orgShapeC ) {
481
+ throw new InvalidArgumentException ('"A" and "C" must have the same number of rows."B" and "C" must have the same number of columns:[ ' .
482
+ implode (', ' ,$ A ->shape ()).'] , [ ' .implode (', ' ,$ B ->shape ()).'] => [ ' .implode (', ' ,$ C ->shape ()).'] ' );
483
+ }
484
+ } else {
485
+ $ C = $ this ->alloc ($ orgShapeC ,$ A ->dtype ());
486
+ $ this ->zeros ($ C );
487
+ }
488
+ $ flatC = $ C ->reshape (array_merge ([$ broadcastDest ],$ shapeEC ));
489
+ $ CC = $ C ->buffer ();
490
+ $ repeats = (int )floor ($ broadcastDest /$ broadcastBase );
491
+ $ offA = $ A ->offset ();
492
+ $ offB = $ B ->offset ();
493
+ $ offC = $ C ->offset ();
494
+ $ incA = $ M *$ K ;
495
+ $ incB = $ N *$ K ;
496
+ $ incC = $ M *$ N ;
497
+ for ($ i =0 ;$ i <$ repeats ;$ i ++) {
498
+ if ($ batchA >$ batchB ) {
499
+ $ offB = $ B ->offset ();
500
+ } else {
501
+ $ offA = $ A ->offset ();
502
+ }
503
+ for ($ j =0 ;$ j <$ broadcastBase ;$ j ++) {
504
+ $ this ->blas ->gemm (
505
+ BLAS ::RowMajor,$ transA ,$ transB ,
506
+ $ M ,$ N ,$ K ,
507
+ $ alpha ,
508
+ $ AA ,$ offA ,$ lda ,
509
+ $ BB ,$ offB ,$ ldb ,
510
+ $ beta ,
511
+ $ CC ,$ offC ,$ ldc );
512
+ $ offA +=$ incA ;
513
+ $ offB +=$ incB ;
514
+ $ offC +=$ incC ;
515
+ }
516
+ }
517
+ return $ C ;
518
+ }
519
+
386
520
/**
387
521
* ret := x_1 + ... + x_n
388
522
*/
@@ -2057,6 +2191,91 @@ public function stack(
2057
2191
return $ output ;
2058
2192
}
2059
2193
2194
+ public function concat (
2195
+ array $ values ,
2196
+ int $ axis =null
2197
+ ) : NDArray
2198
+ {
2199
+ if ($ axis ===null ) {
2200
+ $ axis = -1 ;
2201
+ }
2202
+ if ($ axis <0 ) {
2203
+ $ axis = $ values [0 ]->ndim () + $ axis ;
2204
+ }
2205
+ $ m = null ;
2206
+ $ base = null ;
2207
+ $ n = 0 ;
2208
+ $ reshapeValues = [];
2209
+ foreach ($ values as $ value ) {
2210
+ $ shapePrefix = [];
2211
+ $ shape = $ value ->shape ();
2212
+ $ mm = 1 ;
2213
+ for ($ j =0 ;$ j <$ axis ;$ j ++) {
2214
+ $ mmm = array_shift ($ shape );
2215
+ $ shapePrefix [] = $ mmm ;
2216
+ $ mm *= $ mmm ;
2217
+ }
2218
+ $ nn = array_shift ($ shape );
2219
+ if ($ base ===null ) {
2220
+ $ m = $ mm ;
2221
+ $ base = $ shape ;
2222
+ } else {
2223
+ if ($ m !=$ mm ||$ base !=$ shape ) {
2224
+ throw new InvalidArgumentException ('Unmatch shape: ' .
2225
+ $ this ->printableShapes ($ values ));
2226
+ }
2227
+ }
2228
+ $ n += $ nn ;
2229
+ $ reshapeValues [] = $ value ->reshape (array_merge ([$ mm ,$ nn ],$ shape ));
2230
+ }
2231
+ $ dims = $ shape ;
2232
+ $ shape = array_merge ([$ m ,$ n ],$ shape );
2233
+ $ output = $ this ->alloc ($ shape ,$ values [0 ]->dtype ());
2234
+ $ i = 0 ;
2235
+ foreach ($ reshapeValues as $ value ) {
2236
+ $ nn = $ value ->shape ()[1 ];
2237
+ $ this ->doSlice (true ,
2238
+ $ output ,
2239
+ [0 ,$ i ],[-1 ,$ nn ],
2240
+ $ value
2241
+ );
2242
+ $ i += $ nn ;
2243
+ }
2244
+ $ output = $ output ->reshape (array_merge ($ shapePrefix ,[$ n ],$ dims ));
2245
+ return $ output ;
2246
+ }
2247
+
2248
+ public function split (
2249
+ NDArray $ input , array $ sizeSplits , $ axis =null
2250
+ ) : array
2251
+ {
2252
+ if ($ axis ===null ) {
2253
+ $ axis = -1 ;
2254
+ }
2255
+ if ($ axis <0 ) {
2256
+ $ axis = $ input ->ndim () + $ axis ;
2257
+ }
2258
+ $ shapePrefix = [];
2259
+ $ shape = $ input ->shape ();
2260
+ $ m = 1 ;
2261
+ for ($ j =0 ;$ j <$ axis ;$ j ++) {
2262
+ $ mmm = array_shift ($ shape );
2263
+ $ shapePrefix [] = $ mmm ;
2264
+ $ m *= $ mmm ;
2265
+ }
2266
+ $ n = array_shift ($ shape );
2267
+ $ input = $ input ->reshape (array_merge ([$ m ,$ n ],$ shape ));
2268
+ $ i = 0 ;
2269
+ foreach ($ sizeSplits as $ size ) {
2270
+ $ outputs [] = $ this ->doSlice (false ,
2271
+ $ input ,
2272
+ [0 ,$ i ],[-1 ,$ size ]
2273
+ )->reshape (array_merge ($ shapePrefix ,[$ size ],$ shape ));
2274
+ $ i += $ size ;
2275
+ }
2276
+ return $ outputs ;
2277
+ }
2278
+
2060
2279
protected function doSlice (
2061
2280
bool $ reverse ,
2062
2281
NDArray $ input ,
0 commit comments