Skip to content

Commit

Permalink
chore: convert blog generator to numpower
Browse files Browse the repository at this point in the history
  • Loading branch information
mcharytoniuk committed Feb 14, 2024
1 parent 57e1811 commit f1317c6
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions src/Datasets/Generators/Blob.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Exceptions\InvalidArgumentException;
use NDArray as nd;

use function count;
use function sqrt;
Expand All @@ -32,14 +33,12 @@ class Blob implements Generator
*
* @var Vector
*/
protected \Tensor\Vector $center;
protected nd $center;

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

PHPDoc tag @var for property Rubix\ML\Datasets\Generators\Blob::$center with type Tensor\Vector is not subtype of native type NDArray.

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Property Rubix\ML\Datasets\Generators\Blob::$center has unknown class NDArray as its type.

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

PHPDoc tag @var for property Rubix\ML\Datasets\Generators\Blob::$center with type Tensor\Vector is not subtype of native type NDArray.

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Property Rubix\ML\Datasets\Generators\Blob::$center has unknown class NDArray as its type.

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

PHPDoc tag @var for property Rubix\ML\Datasets\Generators\Blob::$center with type Tensor\Vector is not subtype of native type NDArray.

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Property Rubix\ML\Datasets\Generators\Blob::$center has unknown class NDArray as its type.

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

PHPDoc tag @var for property Rubix\ML\Datasets\Generators\Blob::$center with type Tensor\Vector is not subtype of native type NDArray.

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Property Rubix\ML\Datasets\Generators\Blob::$center has unknown class NDArray as its type.

/**
* The standard deviation of the blob.
*
* @var \Tensor\Vector|int|float
*/
protected $stdDev;
protected int|float|nd $stdDev;

Check failure on line 41 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Property Rubix\ML\Datasets\Generators\Blob::$stdDev has unknown class NDArray as its type.

Check failure on line 41 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Property Rubix\ML\Datasets\Generators\Blob::$stdDev has unknown class NDArray as its type.

Check failure on line 41 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Property Rubix\ML\Datasets\Generators\Blob::$stdDev has unknown class NDArray as its type.

Check failure on line 41 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Property Rubix\ML\Datasets\Generators\Blob::$stdDev has unknown class NDArray as its type.

/**
* Fit a Blob generator to the samples in a dataset.
Expand Down Expand Up @@ -94,15 +93,15 @@ public function __construct(array $center = [0, 0], $stdDev = 1.0)
}
}

$stdDev = Vector::quick($stdDev);
$stdDev = nd::array($stdDev);

Check failure on line 96 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 96 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 96 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 96 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method array() on an unknown class NDArray.
} else {
if ($stdDev < 0) {
throw new InvalidArgumentException('Standard deviation'
. " must be greater than 0, $stdDev given.");
}
}

$this->center = Vector::quick($center);
$this->center = nd::array($center);

Check failure on line 104 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 104 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 104 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 104 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method array() on an unknown class NDArray.
$this->stdDev = $stdDev;
}

Expand All @@ -113,7 +112,7 @@ public function __construct(array $center = [0, 0], $stdDev = 1.0)
*/
public function center() : array
{
return $this->center->asArray();
return $this->center->toArray();

Check failure on line 115 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Call to method toArray() on an unknown class NDArray.

Check failure on line 115 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to method toArray() on an unknown class NDArray.

Check failure on line 115 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to method toArray() on an unknown class NDArray.

Check failure on line 115 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to method toArray() on an unknown class NDArray.
}

/**
Expand All @@ -125,7 +124,7 @@ public function center() : array
*/
public function dimensions() : int
{
return $this->center->n();
return $this->center->size();

Check failure on line 127 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Call to method size() on an unknown class NDArray.

Check failure on line 127 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to method size() on an unknown class NDArray.

Check failure on line 127 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to method size() on an unknown class NDArray.

Check failure on line 127 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to method size() on an unknown class NDArray.
}

/**
Expand All @@ -138,11 +137,10 @@ public function generate(int $n) : Unlabeled
{
$d = $this->dimensions();

$samples = Matrix::gaussian($n, $d)
->multiply($this->stdDev)
->add($this->center)
->asArray();
$samples = nd::normal([$n, $d]);

Check failure on line 140 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Call to static method normal() on an unknown class NDArray.

Check failure on line 140 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method normal() on an unknown class NDArray.

Check failure on line 140 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method normal() on an unknown class NDArray.

Check failure on line 140 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method normal() on an unknown class NDArray.
$samplesMul = nd::multiply($samples, $this->stdDev);

Check failure on line 141 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Call to static method multiply() on an unknown class NDArray.

Check failure on line 141 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method multiply() on an unknown class NDArray.

Check failure on line 141 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method multiply() on an unknown class NDArray.

Check failure on line 141 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method multiply() on an unknown class NDArray.
$samplesMulAddCenter = nd::add($samplesMul, $this->center);

Check failure on line 142 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Call to static method add() on an unknown class NDArray.

Check failure on line 142 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method add() on an unknown class NDArray.

Check failure on line 142 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method add() on an unknown class NDArray.

Check failure on line 142 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method add() on an unknown class NDArray.

return Unlabeled::quick($samples);
return Unlabeled::quick($samplesMulAddCenter->toArray());
}
}

0 comments on commit f1317c6

Please sign in to comment.