Skip to content

Commit 099615e

Browse files
committed
fix float64 on gemm
1 parent 8bd2ffd commit 099615e

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/LinearAlgebra.php

+2-2
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ public function gemm(
609609
throw new InvalidArgumentException('"A" and "C" must have the same number of rows."B" and "C" must have the same number of columns');
610610
}
611611
} else {
612-
$C = $this->zeros($this->alloc([$M,$N]));
612+
$C = $this->zeros($this->alloc([$M,$N],$A->dtype()));
613613
}
614614
$CC = $C->buffer();
615615
$offC = $C->offset();
@@ -1982,7 +1982,7 @@ public function onehot(
19821982
}
19831983
$sizeX = $X->size();
19841984
if($Y===null) {
1985-
$Y = $this->zeros($this->alloc([$sizeX,$numClass]));
1985+
$Y = $this->zeros($this->alloc([$sizeX,$numClass],$this->defaultFloatType));
19861986
}
19871987
if($Y->ndim()!=2) {
19881988
throw new InvalidArgumentException('"Y" must be 2D-NDArray.');

src/LinearAlgebraCL.php

+2-2
Original file line numberDiff line numberDiff line change
@@ -1204,7 +1204,7 @@ public function gemm(
12041204
throw new InvalidArgumentException('"A" and "C" must have the same number of rows."B" and "C" must have the same number of columns');
12051205
}
12061206
} else {
1207-
$C = $this->alloc([$M,$N]);
1207+
$C = $this->alloc([$M,$N],$A->dtype());
12081208
$beta = 0.0;
12091209
}
12101210
$CC = $C->buffer();
@@ -3160,7 +3160,7 @@ public function onehot(
31603160
$addMode = true;
31613161
if($Y===null) {
31623162
$addMode = false;
3163-
$Y = $this->alloc([$sizeX,$numClass]);
3163+
$Y = $this->alloc([$sizeX,$numClass],$this->defaultFloatType);
31643164
$waitPrev = $waitEvents;
31653165
$waitEvents = $this->newEventList();
31663166
$this->zeros($Y,$waitEvents,$waitPrev);

0 commit comments

Comments
 (0)