Skip to content

Commit

Permalink
Make mean and covariance properties on MultivariateGaussian private
Browse files Browse the repository at this point in the history
  • Loading branch information
Skeletonxf committed Dec 1, 2024
1 parent c0ade81 commit 4541cbe
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 38 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ can be updated when using Easy ML 2.0 or later to the following:
fn function_name<T: Real>()
where for<'a> &'a T: RealRef<T> {}
```
- The public properties `mean` and `covariance` on the MultivariateGaussian
struct were made private and methods with the same names were added to return
references to the vector and matrix. This allows the `draw` method to not have
to recheck invariants every time it is called, now matching the
MultivariateGaussianTensor version.

Further trait inheritance changes are planned as detailed at
https://github.com/Skeletonxf/easy-ml/issues/1 but not yet implemented.
Expand Down
67 changes: 31 additions & 36 deletions src/distributions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,32 +261,8 @@ where
*/
#[derive(Clone, Debug)]
pub struct MultivariateGaussian<T: Real> {
// TODO: Make these non public in 2.0 so we don't have to check them again each time we call
// `draw()`
/**
* The mean is a column vector of expected values in each dimension
*/
pub mean: Matrix<T>,
/**
* The covariance matrix is a NxN matrix where N is the number of dimensions for
* this Gaussian. A covariance matrix must always be symmetric, that is `C[i,j] = C[j,i]`.
*
* The covariance matrix is a measure of how much values from each dimension vary
* from their expected value with respect to each other.
*
* For a 2 dimensional multivariate Gaussian the covariance matrix could be the 2x2 identity
* matrix:
*
* ```ignore
* [
* 1.0, 0.0
* 0.0, 1.0
* ]
* ```
*
* In which case the two dimensions are completely uncorrelated as `C[0,1] = C[1,0] = 0`.
*/
pub covariance: Matrix<T>,
mean: Matrix<T>,
covariance: Matrix<T>,
}

impl<T: Real> MultivariateGaussian<T> {
Expand Down Expand Up @@ -317,6 +293,33 @@ impl<T: Real> MultivariateGaussian<T> {
);
MultivariateGaussian { mean, covariance }
}

/**
* The mean is a column vector of expected values in each dimension
*/
pub fn mean(&self) -> &Matrix<T> {
&self.mean
}

/**
* The covariance matrix is a measure of how much values from each dimension vary
* from their expected value with respect to each other.
*
* For a 2 dimensional multivariate Gaussian the covariance matrix could be the 2x2 identity
* matrix:
*
* ```ignore
* [
* 1.0, 0.0
* 0.0, 1.0
* ]
* ```
*
* In which case the two dimensions are completely uncorrelated as `C[0,1] = C[1,0] = 0`.
*/
pub fn covariance(&self) -> &Matrix<T> {
&self.covariance
}
}

impl<T: Real> MultivariateGaussian<T>
Expand Down Expand Up @@ -344,17 +347,9 @@ where
where
I: Iterator<Item = T>,
{
// Since both our fields are public, we have to recheck they're still meeting our
// invariants before doing any calculations.
if self.mean.columns() != 1
|| self.covariance.rows() != self.covariance.columns()
|| self.mean.rows() != self.covariance.rows()
{
return None;
}
use crate::interop::{DimensionNames, RowAndColumn, TensorRefMatrix};
// Since we already validated our state, we wouldn't expect these conversions to fail
// but if they do return None
// Since we already validated our state on construction, we wouldn't expect these
// conversions to fail but if they do return None
// Convert the column vector to a 1 dimensional tensor by selecting the sole column
let mean = crate::tensors::views::TensorIndex::from(
TensorRefMatrix::from(&self.mean).ok()?,
Expand Down
4 changes: 2 additions & 2 deletions tests/distributions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ mod distributions {
.draw(&mut random_source.into_iter(), max_samples)
.unwrap();

let mean = &function.mean;
let covariance = &function.covariance;
let mean = function.mean();
let covariance = function.covariance();

// the mean of the drawn samples should be very close to our mean vector
// check the mean of the samples are within 1 standard deviation
Expand Down

0 comments on commit 4541cbe

Please sign in to comment.