Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert blob generator to numpower #324

Draft
wants to merge 1 commit into
base: 3.0
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions src/Datasets/Generators/Blob.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Exceptions\InvalidArgumentException;
use NDArray as nd;

use function count;
use function sqrt;
Expand All @@ -32,14 +33,12 @@
*
* @var Vector
*/
protected \Tensor\Vector $center;
protected nd $center;

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

PHPDoc tag @var for property Rubix\ML\Datasets\Generators\Blob::$center with type Tensor\Vector is not subtype of native type NDArray.

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Property Rubix\ML\Datasets\Generators\Blob::$center has unknown class NDArray as its type.

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

PHPDoc tag @var for property Rubix\ML\Datasets\Generators\Blob::$center with type Tensor\Vector is not subtype of native type NDArray.

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Property Rubix\ML\Datasets\Generators\Blob::$center has unknown class NDArray as its type.

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

PHPDoc tag @var for property Rubix\ML\Datasets\Generators\Blob::$center with type Tensor\Vector is not subtype of native type NDArray.

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Property Rubix\ML\Datasets\Generators\Blob::$center has unknown class NDArray as its type.

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

PHPDoc tag @var for property Rubix\ML\Datasets\Generators\Blob::$center with type Tensor\Vector is not subtype of native type NDArray.

Check failure on line 36 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Property Rubix\ML\Datasets\Generators\Blob::$center has unknown class NDArray as its type.

/**
* The standard deviation of the blob.
*
* @var \Tensor\Vector|int|float
*/
protected $stdDev;
protected int|float|nd $stdDev;

Check failure on line 41 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Property Rubix\ML\Datasets\Generators\Blob::$stdDev has unknown class NDArray as its type.

Check failure on line 41 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Property Rubix\ML\Datasets\Generators\Blob::$stdDev has unknown class NDArray as its type.

Check failure on line 41 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Property Rubix\ML\Datasets\Generators\Blob::$stdDev has unknown class NDArray as its type.

Check failure on line 41 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Property Rubix\ML\Datasets\Generators\Blob::$stdDev has unknown class NDArray as its type.

/**
* Fit a Blob generator to the samples in a dataset.
Expand Down Expand Up @@ -94,15 +93,15 @@
}
}

$stdDev = Vector::quick($stdDev);
$stdDev = nd::array($stdDev);

Check failure on line 96 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 96 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 96 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 96 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method array() on an unknown class NDArray.
} else {
if ($stdDev < 0) {
throw new InvalidArgumentException('Standard deviation'
. " must be greater than 0, $stdDev given.");
}
}

$this->center = Vector::quick($center);
$this->center = nd::array($center);

Check failure on line 104 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 104 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 104 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method array() on an unknown class NDArray.

Check failure on line 104 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method array() on an unknown class NDArray.
$this->stdDev = $stdDev;
}

Expand All @@ -113,7 +112,7 @@
*/
public function center() : array
{
return $this->center->asArray();
return $this->center->toArray();

Check failure on line 115 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Call to method toArray() on an unknown class NDArray.

Check failure on line 115 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to method toArray() on an unknown class NDArray.

Check failure on line 115 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to method toArray() on an unknown class NDArray.

Check failure on line 115 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to method toArray() on an unknown class NDArray.
}

/**
Expand All @@ -125,7 +124,7 @@
*/
public function dimensions() : int
{
return $this->center->n();
return $this->center->size();

Check failure on line 127 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Call to method size() on an unknown class NDArray.

Check failure on line 127 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to method size() on an unknown class NDArray.

Check failure on line 127 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to method size() on an unknown class NDArray.

Check failure on line 127 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to method size() on an unknown class NDArray.
}

/**
Expand All @@ -138,11 +137,10 @@
{
$d = $this->dimensions();

$samples = Matrix::gaussian($n, $d)
->multiply($this->stdDev)
->add($this->center)
->asArray();
$samples = nd::normal([$n, $d]);

Check failure on line 140 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Call to static method normal() on an unknown class NDArray.

Check failure on line 140 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method normal() on an unknown class NDArray.

Check failure on line 140 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method normal() on an unknown class NDArray.

Check failure on line 140 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method normal() on an unknown class NDArray.
$samplesMul = nd::multiply($samples, $this->stdDev);

Check failure on line 141 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Call to static method multiply() on an unknown class NDArray.

Check failure on line 141 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method multiply() on an unknown class NDArray.

Check failure on line 141 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method multiply() on an unknown class NDArray.

Check failure on line 141 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method multiply() on an unknown class NDArray.
$samplesMulAddCenter = nd::add($samplesMul, $this->center);

Check failure on line 142 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on windows-latest

Call to static method add() on an unknown class NDArray.

Check failure on line 142 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Call to static method add() on an unknown class NDArray.

Check failure on line 142 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Call to static method add() on an unknown class NDArray.

Check failure on line 142 in src/Datasets/Generators/Blob.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Call to static method add() on an unknown class NDArray.

return Unlabeled::quick($samples);
return Unlabeled::quick($samplesMulAddCenter->toArray());
}
}
Loading