From f1317c68c017d7491e010e0657ec514841ce10d4 Mon Sep 17 00:00:00 2001 From: Mateusz Charytoniuk Date: Wed, 14 Feb 2024 02:11:30 +0100 Subject: [PATCH] chore: convert blog generator to numpower --- src/Datasets/Generators/Blob.php | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/Datasets/Generators/Blob.php b/src/Datasets/Generators/Blob.php index 994d6fa52..5380f3a98 100644 --- a/src/Datasets/Generators/Blob.php +++ b/src/Datasets/Generators/Blob.php @@ -9,6 +9,7 @@ use Rubix\ML\Datasets\Dataset; use Rubix\ML\Datasets\Unlabeled; use Rubix\ML\Exceptions\InvalidArgumentException; +use NDArray as nd; use function count; use function sqrt; @@ -32,14 +33,12 @@ class Blob implements Generator * * @var Vector */ - protected \Tensor\Vector $center; + protected nd $center; /** * The standard deviation of the blob. - * - * @var \Tensor\Vector|int|float */ - protected $stdDev; + protected int|float|nd $stdDev; /** * Fit a Blob generator to the samples in a dataset. @@ -94,7 +93,7 @@ public function __construct(array $center = [0, 0], $stdDev = 1.0) } } - $stdDev = Vector::quick($stdDev); + $stdDev = nd::array($stdDev); } else { if ($stdDev < 0) { throw new InvalidArgumentException('Standard deviation' @@ -102,7 +101,7 @@ public function __construct(array $center = [0, 0], $stdDev = 1.0) } } - $this->center = Vector::quick($center); + $this->center = nd::array($center); $this->stdDev = $stdDev; } @@ -113,7 +112,7 @@ public function __construct(array $center = [0, 0], $stdDev = 1.0) */ public function center() : array { - return $this->center->asArray(); + return $this->center->toArray(); } /** @@ -125,7 +124,7 @@ public function center() : array */ public function dimensions() : int { - return $this->center->n(); + return $this->center->size(); } /** @@ -138,11 +137,10 @@ public function generate(int $n) : Unlabeled { $d = $this->dimensions(); - $samples = Matrix::gaussian($n, $d) - ->multiply($this->stdDev) - ->add($this->center) - ->asArray(); + $samples = nd::normal([$n, $d]); + $samplesMul = nd::multiply($samples, $this->stdDev); + $samplesMulAddCenter = nd::add($samplesMul, $this->center); - return Unlabeled::quick($samples); + return Unlabeled::quick($samplesMulAddCenter->toArray()); } }