Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ENH: expand encoders implementation to support full flights.
Browse files Browse the repository at this point in the history
MNT: Add encoding feature to CHANGELOG.

BUG: add dill to the requirements file.
phmbressan authored and Gui-FernandesBR committed Dec 20, 2024
1 parent 2218f0f commit 07032f3
Showing 8 changed files with 161 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -38,6 +38,7 @@ Attention: The newest changes should be on top -->
- ENH: create a dataset of pre-registered motors. See #664 [#744](https://github.com/RocketPy-Team/RocketPy/pull/744)
- DOC: add Defiance flight example [#742](https://github.com/RocketPy-Team/RocketPy/pull/742)
- ENH: Allow for Alternative and Custom ODE Solvers. [#748](https://github.com/RocketPy-Team/RocketPy/pull/748)
- ENH: Expansion of Encoders Implementation for Full Flights. [#679](https://github.com/RocketPy-Team/RocketPy/pull/679)


### Changed
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -5,3 +5,4 @@ netCDF4>=1.6.4
requests
pytz
simplekml
dill
24 changes: 15 additions & 9 deletions rocketpy/_encoders.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
"""Defines a custom JSON encoder for RocketPy objects."""

import json
import types
from datetime import datetime

import numpy as np

from rocketpy.mathutils.function import Function


class RocketPyEncoder(json.JSONEncoder):
"""NOTE: This is still under construction, please don't use it yet."""
"""Custom JSON encoder for RocketPy objects. It defines how to encode
different types of objects to a JSON supported format."""

def default(self, o):
if isinstance(
@@ -33,11 +32,18 @@ def default(self, o):
return float(o)
elif isinstance(o, np.ndarray):
return o.tolist()
elif isinstance(o, datetime):
return o.isoformat()
elif hasattr(o, "__iter__") and not isinstance(o, str):
return list(o)
elif hasattr(o, "to_dict"):
return o.to_dict()
# elif isinstance(o, Function):
# return o.__dict__()
elif isinstance(o, (Function, types.FunctionType)):
return repr(o)
elif hasattr(o, "__dict__"):
exception_set = {"prints", "plots"}
return {
key: value
for key, value in o.__dict__.items()
if key not in exception_set
}
else:
return json.JSONEncoder.default(self, o)
return super().default(o)
20 changes: 13 additions & 7 deletions rocketpy/environment/environment.py
Original file line number Diff line number Diff line change
@@ -366,12 +366,15 @@ def __initialize_constants(self):
self.standard_g = 9.80665
self.__weather_model_map = WeatherModelMapping()
self.__atm_type_file_to_function_map = {
("forecast", "GFS"): fetch_gfs_file_return_dataset,
("forecast", "NAM"): fetch_nam_file_return_dataset,
("forecast", "RAP"): fetch_rap_file_return_dataset,
("forecast", "HIRESW"): fetch_hiresw_file_return_dataset,
("ensemble", "GEFS"): fetch_gefs_ensemble,
# ("ensemble", "CMC"): fetch_cmc_ensemble,
"forecast": {
"GFS": fetch_gfs_file_return_dataset,
"NAM": fetch_nam_file_return_dataset,
"RAP": fetch_rap_file_return_dataset,
"HIRESW": fetch_hiresw_file_return_dataset,
},
"ensemble": {
"GEFS": fetch_gefs_ensemble,
},
}
self.__standard_atmosphere_layers = {
"geopotential_height": [ # in geopotential m
@@ -1270,7 +1273,10 @@ def set_atmospheric_model( # pylint: disable=too-many-statements
self.process_windy_atmosphere(file)
elif type in ["forecast", "reanalysis", "ensemble"]:
dictionary = self.__validate_dictionary(file, dictionary)
fetch_function = self.__atm_type_file_to_function_map.get((type, file))
try:
fetch_function = self.__atm_type_file_to_function_map[type][file]
except KeyError:
fetch_function = None

# Fetches the dataset using OpenDAP protocol or uses the file path
dataset = fetch_function() if fetch_function is not None else file
49 changes: 49 additions & 0 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
@@ -5,14 +5,17 @@
carefully as it may impact all the rest of the project.
"""

import base64
import warnings
import zlib
from bisect import bisect_left
from collections.abc import Iterable
from copy import deepcopy
from functools import cached_property
from inspect import signature
from pathlib import Path

import dill
import matplotlib.pyplot as plt
import numpy as np
from scipy import integrate, linalg, optimize
@@ -3418,6 +3421,52 @@ def __validate_extrapolation(self, extrapolation):
extrapolation = "natural"
return extrapolation

def to_dict(self):
"""Serializes the Function instance to a dictionary.
Returns
-------
dict
A dictionary containing the Function's attributes.
"""
source = self.source

if callable(source):
source = zlib.compress(base64.b85encode(dill.dumps(source))).hex()

return {
"source": source,
"title": self.title,
"inputs": self.__inputs__,
"outputs": self.__outputs__,
"interpolation": self.__interpolation__,
"extrapolation": self.__extrapolation__,
}

@classmethod
def from_dict(cls, func_dict):
"""Creates a Function instance from a dictionary.
Parameters
----------
func_dict
The JSON like Function dictionary.
"""
source = func_dict["source"]
if func_dict["interpolation"] is None and func_dict["extrapolation"] is None:
source = dill.loads(
base64.b85decode(zlib.decompress(bytes.fromhex(source)))
)

return cls(
source=source,
interpolation=func_dict["interpolation"],
extrapolation=func_dict["extrapolation"],
inputs=func_dict["inputs"],
outputs=func_dict["outputs"],
title=func_dict["title"],
)


def funcify_method(*args, **kwargs): # pylint: disable=too-many-statements
"""Decorator factory to wrap methods as Function objects and save them as
16 changes: 16 additions & 0 deletions rocketpy/motors/tank_geometry.py
Original file line number Diff line number Diff line change
@@ -346,6 +346,22 @@ def add_geometry(self, domain, radius_function):
self._geometry[domain] = Function(radius_function)
self.radius = PiecewiseFunction(self._geometry, "Height (m)", "radius (m)")

def to_dict(self):
"""
Returns a dictionary representation of the TankGeometry object.
Returns
-------
dict
Dictionary representation of the TankGeometry object.
"""
return {
"geometry": {
str(domain): function.to_dict()
for domain, function in self._geometry.items()
}
}


class CylindricalTank(TankGeometry):
"""Class to define the geometry of a cylindrical tank. The cylinder has
15 changes: 15 additions & 0 deletions rocketpy/rocket/components.py
Original file line number Diff line number Diff line change
@@ -193,3 +193,18 @@ def sort_by_position(self, reverse=False):
None
"""
self._components.sort(key=lambda x: x.position.z, reverse=reverse)

def to_dict(self):
"""Return a dictionary representation of the components.
Returns
-------
dict
A dictionary representation of the components.
"""
return {
"components": [
{"component": c.component, "position": c.position}
for c in self._components
]
}
51 changes: 51 additions & 0 deletions tests/unit/test_encoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import json

import pytest

from rocketpy._encoders import RocketPyEncoder

# TODO: this tests should be improved with better validation and decoding


@pytest.mark.parametrize("flight_name", ["flight_calisto", "flight_calisto_robust"])
def test_encode_flight(flight_name, request):
"""Test encoding a ``rocketpy.Flight``.
Parameters
----------
flight_name : str
Name flight fixture to encode.
request : pytest.FixtureRequest
Pytest request object.
"""
flight = request.getfixturevalue(flight_name)

json_encoded = json.dumps(flight, cls=RocketPyEncoder)

flight_dict = json.loads(json_encoded)

assert json_encoded is not None
assert flight_dict is not None


@pytest.mark.parametrize(
"function_name", ["lambda_quad_func", "spline_interpolated_func"]
)
def test_encode_function(function_name, request):
"""Test encoding a ``rocketpy.Function``.
Parameters
----------
function_name : str
Name of the function to encode.
request : pytest.FixtureRequest
Pytest request object.
"""
function = request.getfixturevalue(function_name)

json_encoded = json.dumps(function, cls=RocketPyEncoder)

function_dict = json.loads(json_encoded)

assert json_encoded is not None
assert function_dict is not None

0 comments on commit 07032f3

Please sign in to comment.