Skip to content

Commit

Permalink
fix: add pickle support for ElevationMask
Browse files Browse the repository at this point in the history
  • Loading branch information
helgee committed Nov 15, 2024
1 parent ab44a41 commit c411c78
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.x
python-version: 3.12
- name: Build wheels
uses: PyO3/maturin-action@v1
with:
Expand Down
55 changes: 51 additions & 4 deletions crates/lox-orbits/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -999,13 +999,40 @@ impl From<ElevationMaskError> for PyErr {
}
}

#[pyclass(name = "ElevationMask", module = "lox_space", frozen)]
#[pyclass(name = "ElevationMask", module = "lox_space", frozen, eq)]
#[derive(Debug, Clone, PartialEq)]
pub struct PyElevationMask(pub ElevationMask);

#[pymethods]
impl PyElevationMask {
#[new]
#[pyo3(signature = (azimuth=None, elevation=None, min_elevation=None))]
fn new(
azimuth: Option<&Bound<'_, PyArray1<f64>>>,
elevation: Option<&Bound<'_, PyArray1<f64>>>,
min_elevation: Option<f64>,
) -> PyResult<Self> {
if let Some(min_elevation) = min_elevation {
return Ok(PyElevationMask(ElevationMask::with_fixed_elevation(
min_elevation,
)));
}
if let (Some(azimuth), Some(elevation)) = (azimuth, elevation) {
let azimuth = azimuth.to_vec()?;
let elevation = elevation.to_vec()?;
return Ok(PyElevationMask(ElevationMask::new(azimuth, elevation)?));
}
Err(PyValueError::new_err("invalid argument combination, either `min_elevation` or `azimuth` and `elevation` arrays need to be present"))
}

#[classmethod]
fn fixed(_cls: &Bound<'_, PyType>, min_elevation: f64) -> Self {
PyElevationMask(ElevationMask::with_fixed_elevation(min_elevation))
}

#[classmethod]
fn variable(
_cls: &Bound<'_, PyType>,
azimuth: &Bound<'_, PyArray1<f64>>,
elevation: &Bound<'_, PyArray1<f64>>,
) -> PyResult<Self> {
Expand All @@ -1014,9 +1041,29 @@ impl PyElevationMask {
Ok(PyElevationMask(ElevationMask::new(azimuth, elevation)?))
}

#[classmethod]
fn fixed(_cls: &Bound<'_, PyType>, min_elevation: f64) -> Self {
PyElevationMask(ElevationMask::with_fixed_elevation(min_elevation))
fn __getnewargs__(&self) -> (Option<Vec<f64>>, Option<Vec<f64>>, Option<f64>) {
(self.azimuth(), self.elevation(), self.min_elevation())
}

fn azimuth(&self) -> Option<Vec<f64>> {
match &self.0 {
ElevationMask::Fixed(_) => None,
ElevationMask::Variable(series) => Some(series.x().to_vec()),
}
}

fn elevation(&self) -> Option<Vec<f64>> {
match &self.0 {
ElevationMask::Fixed(_) => None,
ElevationMask::Variable(series) => Some(series.y().to_vec()),
}
}

fn min_elevation(&self) -> Option<f64> {
match &self.0 {
ElevationMask::Fixed(min_elevation) => Some(*min_elevation),
ElevationMask::Variable(_) => None,
}
}
}

Expand Down
7 changes: 4 additions & 3 deletions crates/lox-space/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
use lox_bodies::python::{PyBarycenter, PyMinorBody, PyPlanet, PySatellite, PySun};
use lox_ephem::python::PySpk;
use lox_orbits::python::{
elevation, find_events, find_windows, visibility, PyEvent, PyFrame, PyGroundLocation,
PyGroundPropagator, PyKeplerian, PyObservables, PySgp4, PyState, PyTopocentric, PyTrajectory,
PyVallado, PyWindow,
elevation, find_events, find_windows, visibility, PyElevationMask, PyEvent, PyFrame,
PyGroundLocation, PyGroundPropagator, PyKeplerian, PyObservables, PySgp4, PyState,
PyTopocentric, PyTrajectory, PyVallado, PyWindow,
};
use pyo3::prelude::*;

Expand Down Expand Up @@ -50,5 +50,6 @@ fn lox_space(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PySeries>()?;
m.add_class::<PyObservables>()?;
m.add_class::<PySpk>()?;
m.add_class::<PyElevationMask>()?;
Ok(())
}
1 change: 1 addition & 0 deletions crates/lox-space/tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
lox.Planet("Earth"),
lox.Satellite("Moon"),
lox.MinorBody("Ceres"),
lox.ElevationMask.fixed(0.0),
])
def test_pickle(obj):
pickled = pickle.dumps(obj)
Expand Down

0 comments on commit c411c78

Please sign in to comment.