12
12
See the License for the specific language governing permissions and
13
13
limitations under the License.
14
14
*/
15
+ // tslint:disable:no-multi-spaces
15
16
import { test } from "../tools/tester" ;
16
17
import { concat , fill , grad , linspace , listDevices , multigrad ,
17
18
ones , Params , randn , range , stack , tensor , Tensor ,
@@ -342,9 +343,9 @@ test(async function api_reverse() {
342
343
assertAllEqual ( tensor ( [ 1 , 2 , 3 , 4 ] ) . reverse ( ) , [ 4 , 3 , 2 , 1 ] ) ;
343
344
344
345
const t = tensor ( [ [
345
- [ [ 0 , 1 , 2 , 3 ] ,
346
- [ 4 , 5 , 6 , 7 ] ,
347
- [ 8 , 9 , 10 , 11 ] ] ,
346
+ [ [ 0 , 1 , 2 , 3 ] ,
347
+ [ 4 , 5 , 6 , 7 ] ,
348
+ [ 8 , 9 , 10 , 11 ] ] ,
348
349
[ [ 12 , 13 , 14 , 15 ] ,
349
350
[ 16 , 17 , 18 , 19 ] ,
350
351
[ 20 , 21 , 22 , 23 ] ]
@@ -354,16 +355,16 @@ test(async function api_reverse() {
354
355
[ [ 12 , 13 , 14 , 15 ] ,
355
356
[ 16 , 17 , 18 , 19 ] ,
356
357
[ 20 , 21 , 22 , 23 ] ] ,
357
- [ [ 0 , 1 , 2 , 3 ] ,
358
- [ 4 , 5 , 6 , 7 ] ,
359
- [ 8 , 9 , 10 , 11 ] ]
358
+ [ [ 0 , 1 , 2 , 3 ] ,
359
+ [ 4 , 5 , 6 , 7 ] ,
360
+ [ 8 , 9 , 10 , 11 ] ]
360
361
] ] ) ;
361
362
assertAllEqual ( t . reverse ( [ 1 ] ) , tR1 ) ;
362
363
assertAllEqual ( t . reverse ( [ - 3 ] ) , tR1 ) ;
363
364
const tR2 = tensor ( [ [
364
- [ [ 8 , 9 , 10 , 11 ] ,
365
- [ 4 , 5 , 6 , 7 ] ,
366
- [ 0 , 1 , 2 , 3 ] ] ,
365
+ [ [ 8 , 9 , 10 , 11 ] ,
366
+ [ 4 , 5 , 6 , 7 ] ,
367
+ [ 0 , 1 , 2 , 3 ] ] ,
367
368
[ [ 20 , 21 , 22 , 23 ] ,
368
369
[ 16 , 17 , 18 , 19 ] ,
369
370
[ 12 , 13 , 14 , 15 ] ]
@@ -527,8 +528,8 @@ testDevices(async function api_onesAndZerosLike(tensor, device) {
527
528
const zeros = a . zerosLike ( ) ;
528
529
assertEqual ( ones . device , device ) ;
529
530
assertEqual ( zeros . device , device ) ;
530
- assertAllEqual ( ones , [ [ 1 , 1 , 1 ] , [ 1 , 1 , 1 ] ] ) ;
531
- assertAllEqual ( zeros , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
531
+ assertAllEqual ( ones , [ [ 1 , 1 , 1 ] , [ 1 , 1 , 1 ] ] ) ;
532
+ assertAllEqual ( zeros , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
532
533
} ) ;
533
534
534
535
test ( async function api_equal ( ) {
@@ -543,14 +544,14 @@ test(async function api_equal() {
543
544
const r = a . equal ( b ) ;
544
545
assertEqual ( r . dtype , "bool" ) ;
545
546
// TODO Allow assertAllEqual to handle boolean.
546
- assertAllEqual ( r , [ [ 1 , 0 , 1 ] , [ 0 , 1 , 0 ] ] ) ;
547
+ assertAllEqual ( r , [ [ 1 , 0 , 1 ] , [ 0 , 1 , 0 ] ] ) ;
547
548
548
549
// equal isn't differentiable but it should have the same behavior as
549
550
// autograd does.
550
551
const f = ( x , y ) => tensor ( x ) . equal ( y ) ;
551
552
const g = multigrad ( f , [ 0 , 1 ] ) ;
552
- assertAllEqual ( g ( a , b ) [ 0 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
553
- assertAllEqual ( g ( a , b ) [ 1 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
553
+ assertAllEqual ( g ( a , b ) [ 0 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
554
+ assertAllEqual ( g ( a , b ) [ 1 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
554
555
} ) ;
555
556
556
557
test ( async function api_greater ( ) {
@@ -565,13 +566,13 @@ test(async function api_greater() {
565
566
const r = a . greater ( b ) ;
566
567
assertEqual ( r . dtype , "bool" ) ;
567
568
// TODO Allow assertAllEqual to handle boolean.
568
- assertAllEqual ( r , [ [ 0 , 1 , 0 ] , [ 1 , 0 , 0 ] ] ) ;
569
+ assertAllEqual ( r , [ [ 0 , 1 , 0 ] , [ 1 , 0 , 0 ] ] ) ;
569
570
// greater isn't differentiable but it should have the same behavior as
570
571
// autograd does.
571
572
const f = ( x , y ) => tensor ( x ) . greater ( y ) ;
572
573
const g = multigrad ( f , [ 0 , 1 ] ) ;
573
- assertAllEqual ( g ( a , b ) [ 0 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
574
- assertAllEqual ( g ( a , b ) [ 1 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
574
+ assertAllEqual ( g ( a , b ) [ 0 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
575
+ assertAllEqual ( g ( a , b ) [ 1 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
575
576
} ) ;
576
577
577
578
test ( async function api_greaterEqual ( ) {
@@ -586,13 +587,13 @@ test(async function api_greaterEqual() {
586
587
const r = a . greaterEqual ( b ) ;
587
588
assertEqual ( r . dtype , "bool" ) ;
588
589
// TODO Allow assertAllEqual to handle boolean.
589
- assertAllEqual ( r , [ [ 1 , 1 , 1 ] , [ 1 , 1 , 0 ] ] ) ;
590
+ assertAllEqual ( r , [ [ 1 , 1 , 1 ] , [ 1 , 1 , 0 ] ] ) ;
590
591
// greaterEqual isn't differentiable but it should have the same behavior as
591
592
// autograd does.
592
593
const f = ( x , y ) => tensor ( x ) . greaterEqual ( y ) ;
593
594
const g = multigrad ( f , [ 0 , 1 ] ) ;
594
- assertAllEqual ( g ( a , b ) [ 0 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
595
- assertAllEqual ( g ( a , b ) [ 1 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
595
+ assertAllEqual ( g ( a , b ) [ 0 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
596
+ assertAllEqual ( g ( a , b ) [ 1 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
596
597
} ) ;
597
598
598
599
test ( async function api_less ( ) {
@@ -607,13 +608,13 @@ test(async function api_less() {
607
608
const r = a . less ( b ) ;
608
609
assertEqual ( r . dtype , "bool" ) ;
609
610
// TODO Allow assertAllEqual to handle boolean.
610
- assertAllEqual ( r , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 1 ] ] ) ;
611
+ assertAllEqual ( r , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 1 ] ] ) ;
611
612
// less isn't differentiable but it should have the same behavior as
612
613
// autograd does.
613
614
const f = ( x , y ) => tensor ( x ) . less ( y ) ;
614
615
const g = multigrad ( f , [ 0 , 1 ] ) ;
615
- assertAllEqual ( g ( a , b ) [ 0 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
616
- assertAllEqual ( g ( a , b ) [ 1 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
616
+ assertAllEqual ( g ( a , b ) [ 0 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
617
+ assertAllEqual ( g ( a , b ) [ 1 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
617
618
} ) ;
618
619
619
620
test ( async function api_lessEqual ( ) {
@@ -628,13 +629,13 @@ test(async function api_lessEqual() {
628
629
const r = a . lessEqual ( b ) ;
629
630
assertEqual ( r . dtype , "bool" ) ;
630
631
// TODO Allow assertAllEqual to handle boolean.
631
- assertAllEqual ( r , [ [ 1 , 0 , 1 ] , [ 0 , 1 , 1 ] ] ) ;
632
+ assertAllEqual ( r , [ [ 1 , 0 , 1 ] , [ 0 , 1 , 1 ] ] ) ;
632
633
// lessEqual isn't differentiable but it should have the same behavior as
633
634
// autograd does.
634
635
const f = ( x , y ) => tensor ( x ) . lessEqual ( y ) ;
635
636
const g = multigrad ( f , [ 0 , 1 ] ) ;
636
- assertAllEqual ( g ( a , b ) [ 0 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
637
- assertAllEqual ( g ( a , b ) [ 1 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
637
+ assertAllEqual ( g ( a , b ) [ 0 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
638
+ assertAllEqual ( g ( a , b ) [ 1 ] , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
638
639
} ) ;
639
640
640
641
test ( async function api_select ( ) {
@@ -643,7 +644,7 @@ test(async function api_select() {
643
644
[ 4 , 5 , 6 ] ,
644
645
] ) ;
645
646
const f = tensor ( [
646
- [ 7 , 8 , 9 ] ,
647
+ [ 7 , 8 , 9 ] ,
647
648
[ 10 , 11 , 12 ] ,
648
649
] ) ;
649
650
// TODO Use false/true literals instead of 0 and 1 in cond.
@@ -653,12 +654,12 @@ test(async function api_select() {
653
654
] , { dtype : "bool" } ) ;
654
655
const r = cond . select ( t , f ) ;
655
656
assertAllEqual ( r , [
656
- [ 1 , 8 , 3 ] ,
657
+ [ 1 , 8 , 3 ] ,
657
658
[ 10 , 5 , 12 ] ,
658
659
] ) ;
659
660
// select isn't differentiable.
660
661
const g = grad ( ( c ) => c . select ( t , f ) ) ;
661
- assertAllEqual ( g ( cond ) , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
662
+ assertAllEqual ( g ( cond ) , [ [ 0 , 0 , 0 ] , [ 0 , 0 , 0 ] ] ) ;
662
663
663
664
function f2 ( x ) {
664
665
x = tensor ( x ) ;
@@ -715,8 +716,8 @@ testDevices(async function api_pad(tensor, device) {
715
716
const padded2 = d2 . pad ( [ [ 1 , 2 ] , [ 0 , 0 ] ] , 42 ) ;
716
717
assertAllEqual ( padded2 , [
717
718
[ 42 , 42 , 42 ] ,
718
- [ 9 , 5 , 7 ] ,
719
- [ 6 , 8 , 4 ] ,
719
+ [ 9 , 5 , 7 ] ,
720
+ [ 6 , 8 , 4 ] ,
720
721
[ 42 , 42 , 42 ] ,
721
722
[ 42 , 42 , 42 ] ,
722
723
] ) ;
@@ -1002,8 +1003,8 @@ testDevices(async function api_gather(tensor, device) {
1002
1003
[ 1 , 2 , 3 , 4 ] ,
1003
1004
] ) ;
1004
1005
assertAllEqual ( t . gather ( [ 2 , 0 ] , 1 ) , [
1005
- [ 3 , 1 ] ,
1006
- [ 7 , 5 ] ,
1006
+ [ 3 , 1 ] ,
1007
+ [ 7 , 5 ] ,
1007
1008
[ 11 , 9 ]
1008
1009
] ) ;
1009
1010
} ) ;
@@ -1068,7 +1069,7 @@ testDevices(async function api_oneHot(tensor, device) {
1068
1069
1069
1070
const b = tensor ( [ 0 , 1 , 3 , 4 ] , { dtype : "int32" } ) ;
1070
1071
assertAllEqual ( b . oneHot ( 5 , 0.5 , - 0.5 ) , [
1071
- [ 0.5 , - 0.5 , - 0.5 , - 0.5 , - 0.5 ] ,
1072
+ [ 0.5 , - 0.5 , - 0.5 , - 0.5 , - 0.5 ] ,
1072
1073
[ - 0.5 , 0.5 , - 0.5 , - 0.5 , - 0.5 ] ,
1073
1074
[ - 0.5 , - 0.5 , - 0.5 , 0.5 , - 0.5 ] ,
1074
1075
[ - 0.5 , - 0.5 , - 0.5 , - 0.5 , 0.5 ] ,
@@ -1094,9 +1095,9 @@ test(async function api_softmaxCE() {
1094
1095
assertAllClose ( ce , [ 12.00034142 , 8.00034142 , 3.6003418 ] ) ;
1095
1096
const g = grad ( f ) ;
1096
1097
assertAllClose ( g ( logits ) , [
1097
- [ - 9.99993861e-01 , 3.35348042e-04 , 9.99658465e-01 ] ,
1098
- [ 6.14211376e-06 , - 9.99664664e-01 , 9.99658465e-01 ] ,
1099
- [ - 2.99993873e-01 , 3.35348042e-04 , 2.99658477e-01 ]
1098
+ [ - 9.99993861e-01 , 3.35348042e-04 , 9.99658465e-01 ] ,
1099
+ [ 6.14211376e-06 , - 9.99664664e-01 , 9.99658465e-01 ] ,
1100
+ [ - 2.99993873e-01 , 3.35348042e-04 , 2.99658477e-01 ]
1100
1101
] ) ;
1101
1102
} ) ;
1102
1103
@@ -1129,7 +1130,7 @@ testDevices(async function api_neuralNet(tensor, device) {
1129
1130
const inference = ( params : Params , images : Tensor ) => {
1130
1131
let inputs = images . cast ( "float32" ) . div ( 255 ) . reshape ( [ - 1 , 28 * 28 ] ) ;
1131
1132
let outputs ;
1132
- const layerSizes = [ 28 * 28 , 64 , 10 ] ;
1133
+ const layerSizes = [ 28 * 28 , 64 , 10 ] ;
1133
1134
for ( let i = 0 ; i < layerSizes . length - 1 ; ++ i ) {
1134
1135
const m = layerSizes [ i ] ;
1135
1136
const n = layerSizes [ i + 1 ] ;
@@ -1258,10 +1259,10 @@ test(async function api_conv2d() {
1258
1259
assertShapesEqual ( g_ [ 0 ] . shape , img . shape ) ;
1259
1260
assertShapesEqual ( g_ [ 1 ] . shape , filter . shape ) ;
1260
1261
assertAllEqual ( g_ [ 0 ] . squeeze ( ) , [
1261
- [ 0 , 1 , 1 , 1 ] ,
1262
- [ 2 , 6 , 6 , 4 ] ,
1263
- [ 2 , 6 , 6 , 4 ] ,
1264
- [ 2 , 5 , 5 , 3 ] ,
1262
+ [ 0 , 1 , 1 , 1 ] ,
1263
+ [ 2 , 6 , 6 , 4 ] ,
1264
+ [ 2 , 6 , 6 , 4 ] ,
1265
+ [ 2 , 5 , 5 , 3 ] ,
1265
1266
] ) ;
1266
1267
assertAllEqual ( g_ [ 1 ] . squeeze ( ) , [ [ 45 , 54 ] , [ 81 , 90 ] ] ) ;
1267
1268
} ) ;
@@ -1279,10 +1280,10 @@ test(async function api_maxPool() {
1279
1280
const gx = g ( x ) ;
1280
1281
assertShapesEqual ( gx . shape , x . shape ) ;
1281
1282
assertAllEqual ( gx . squeeze ( ) , [
1282
- [ 0 , 0 , 0 , 0 ] ,
1283
- [ 0 , 1 , 0 , 1 ] ,
1284
- [ 0 , 0 , 0 , 0 ] ,
1285
- [ 0 , 1 , 0 , 1 ] ,
1283
+ [ 0 , 0 , 0 , 0 ] ,
1284
+ [ 0 , 1 , 0 , 1 ] ,
1285
+ [ 0 , 0 , 0 , 0 ] ,
1286
+ [ 0 , 1 , 0 , 1 ] ,
1286
1287
] ) ;
1287
1288
} ) ;
1288
1289
@@ -1329,8 +1330,8 @@ test(async function api_size() {
1329
1330
1330
1331
test ( async function api_stopGradientSwallowedErr ( ) {
1331
1332
function loss ( params ) {
1332
- const a = api . zeros ( [ 5 ] ) ;
1333
- const b = api . zeros ( [ 11 ] ) ;
1333
+ const a = api . zeros ( [ 5 ] ) ;
1334
+ const b = api . zeros ( [ 11 ] ) ;
1334
1335
b . stopGradient ( ) ;
1335
1336
assert ( ! shapesEqual ( a . shape , b . shape ) ) ;
1336
1337
// Because the shapes aren't equal, they should throw error when added
0 commit comments