diff --git a/python/egobox/tests/test_gpmix.py b/python/egobox/tests/test_gpmix.py
index ab3254bb..e5ce99d2 100644
--- a/python/egobox/tests/test_gpmix.py
+++ b/python/egobox/tests/test_gpmix.py
@@ -22,11 +22,11 @@ def griewank(x):
 
 class TestGpMix(unittest.TestCase):
     def setUp(self):
-        xt = np.array([[0.0, 1.0, 2.0, 3.0, 4.0]]).T
-        yt = np.array([[0.0, 1.0, 1.5, 0.9, 1.0]]).T
+        self.xt = np.array([[0.0, 1.0, 2.0, 3.0, 4.0]]).T
+        self.yt = np.array([[0.0, 1.0, 1.5, 0.9, 1.0]]).T
 
         gpmix = egx.GpMix()  # or egx.Gpx.builder()
-        self.gpx = gpmix.fit(xt, yt)
+        self.gpx = gpmix.fit(self.xt, self.yt)
 
     def test_gpx_kriging(self):
         gpx = self.gpx
@@ -76,6 +76,12 @@ def test_gpx_save_load(self):
             0.0, gpx2.predict_var(np.array([[1.1]])).item(), delta=1e-3
         )
 
+    def test_training_params(self):
+        self.assertEqual(self.gpx.dims(), (1, 1))
+        (xdata, ydata) = self.gpx.training_data()
+        np.testing.assert_array_equal(xdata, self.xt)
+        np.testing.assert_array_equal(ydata, self.yt)
+
     def test_kpls_griewank(self):
         lb = -600
         ub = 600
@@ -106,6 +112,7 @@ def test_kpls_griewank(self):
         for builder in builders:
            gpx = builder.fit(x_train, y_train)
            y_pred = gpx.predict(x_test)
+           self.assertEqual(100, gpx.dims()[0])
            error = np.linalg.norm(y_pred - y_test) / np.linalg.norm(y_test)
            print(" RMS error: " + str(error))
 
diff --git a/src/gp_mix.rs b/src/gp_mix.rs
index 446e5f53..03806afd 100644
--- a/src/gp_mix.rs
+++ b/src/gp_mix.rs
@@ -10,6 +10,7 @@
 //! See the [tutorial notebook](https://github.com/relf/egobox/doc/Gpx_Tutorial.ipynb) for usage.
 //!
 use crate::types::*;
+use egobox_gp::metrics::CrossValScore;
 use egobox_moe::{Clustered, MixtureGpSurrogate, ThetaTuning};
 #[allow(unused_imports)] // Avoid linting problem
 use egobox_moe::{GpMixture, GpSurrogate, GpSurrogateExt};
@@ -356,6 +357,31 @@ impl Gpx {
             .into_pyarray_bound(py)
     }
 
+    /// Get the input and output dimensions of the surrogate
+    ///
+    /// Returns
+    ///     the couple (nx, ny)
+    ///
+    fn dims(&self) -> (usize, usize) {
+        self.0.dims()
+    }
+
+    /// Get the nt training data points used to fit the surrogate
+    ///
+    /// Returns
+    ///     the couple (ndarray[nt, nx], ndarray[nt, ny])
+    ///
+    fn training_data<'py>(
+        &self,
+        py: Python<'py>,
+    ) -> (Bound<'py, PyArray2<f64>>, Bound<'py, PyArray2<f64>>) {
+        let (xdata, ydata) = self.0.training_data();
+        (
+            xdata.to_owned().into_pyarray_bound(py),
+            ydata.to_owned().into_pyarray_bound(py),
+        )
+    }
+
     /// Get optimized thetas hyperparameters (ie once GP experts are fitted)
     ///
     /// Returns
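
For reference, a minimal Python sketch (not part of the patch) of how the new dims() / training_data() bindings added above are meant to be used, mirroring test_training_params; it assumes the egobox Python package built from this branch:

import numpy as np
import egobox as egx

# Fit a 1D Kriging surrogate on five training points, as in setUp() above
xt = np.array([[0.0, 1.0, 2.0, 3.0, 4.0]]).T
yt = np.array([[0.0, 1.0, 1.5, 0.9, 1.0]]).T
gpx = egx.Gpx.builder().fit(xt, yt)

nx, ny = gpx.dims()                 # -> (1, 1): one input and one output dimension
xdata, ydata = gpx.training_data()  # -> the (5, 1) arrays that were passed to fit()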