Skip to content

Commit da39fcb

Browse files
committed
Improve load_from_thrustcurve_api and test_load_from_thrustcurve_api with clean imports
1 parent 05ab711 commit da39fcb

File tree

2 files changed

+123
-50
lines changed

2 files changed

+123
-50
lines changed

rocketpy/motors/motor.py

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,18 @@
55
import xml.etree.ElementTree as ET
66
from abc import ABC, abstractmethod
77
from functools import cached_property
8-
from os import path
8+
from os import path, remove
99

1010
import numpy as np
1111
import requests
12-
12+
import logging
1313
from ..mathutils.function import Function, funcify_method
1414
from ..plots.motor_plots import _MotorPlots
1515
from ..prints.motor_prints import _MotorPrints
1616
from ..tools import parallel_axis_theorem_from_com, tuple_handler
1717

18+
logger = logging.getLogger(__name__)
19+
1820

1921
# pylint: disable=too-many-public-methods
2022
class Motor(ABC):
@@ -1916,7 +1918,7 @@ def load_from_rse_file(
19161918
interpolation_method=interpolation_method,
19171919
coordinate_system_orientation=coordinate_system_orientation,
19181920
)
1919-
1921+
19201922
@staticmethod
19211923
def load_from_thrustcurve_api(name: str, **kwargs):
19221924
"""
@@ -1926,16 +1928,25 @@ def load_from_thrustcurve_api(name: str, **kwargs):
19261928
Parameters
19271929
----------
19281930
name : str
1929-
The motor name according to the API (e.g., "Cesaroni_M1670").
1931+
The motor name according to the API (e.g., "Cesaroni_M1670" or "M1670").
1932+
Both manufacturer-prefixed and shorthand names are commonly used; if multiple
1933+
motors match the search, the first result is used.
19301934
**kwargs :
1931-
Additional arguments passed to the Motor constructor, such as dry_mass, nozzle_radius, etc.
1935+
Additional arguments passed to the Motor constructor or loader, such as
1936+
dry_mass, nozzle_radius, etc.
19321937
19331938
Returns
19341939
-------
1935-
instance : cls
1936-
A new Motor instance initialized using the downloaded .eng file.
1940+
instance : GenericMotor
1941+
A new GenericMotor instance initialized using the downloaded .eng file.
1942+
1943+
Raises
1944+
------
1945+
ValueError
1946+
If no motor is found or if the downloaded .eng data is missing.
1947+
requests.exceptions.RequestException
1948+
If a network or HTTP error occurs during the API call.
19371949
"""
1938-
19391950
base_url = "https://www.thrustcurve.org/api/v1"
19401951

19411952
# Step 1. Search motor
@@ -1944,34 +1955,55 @@ def load_from_thrustcurve_api(name: str, **kwargs):
19441955
data = response.json()
19451956

19461957
if not data.get("results"):
1947-
print("No motor found.")
1948-
return None
1958+
raise ValueError(
1959+
f"No motor found for name '{name}'. "
1960+
"Please verify the motor name format (e.g., 'Cesaroni_M1670' or 'M1670') and try again."
1961+
)
19491962

1950-
motor = data["results"][0]
1951-
motor_id = motor["motorId"]
1952-
designation = motor["designation"].replace("/", "-")
1953-
print(f"Motor found: {designation} ({motor['manufacturer']})")
1963+
motor_info = data["results"][0]
1964+
motor_id = motor_info.get("motorId")
1965+
designation = motor_info.get("designation", "").replace("/", "-")
1966+
manufacturer = motor_info.get("manufacturer", "")
1967+
# Logging the fact that the motor was found
1968+
logger.info(f"Motor found: {designation} ({manufacturer})")
19541969

19551970
# Step 2. Download the .eng file
19561971
dl_response = requests.get(
19571972
f"{base_url}/download.json",
19581973
params={"motorIds": motor_id, "format": "RASP", "data": "file"},
19591974
)
19601975
dl_response.raise_for_status()
1961-
data = dl_response.json()
1976+
dl_data = dl_response.json()
19621977

1963-
data_base64 = data["results"][0]["data"]
1964-
data_bytes = base64.b64decode(data_base64)
1965-
1966-
# Step 3. Create the motor from the .eng file
1978+
if not dl_data.get("results"):
1979+
raise ValueError(f"No .eng file found for motor '{name}' in the ThrustCurve API.")
19671980

1968-
with tempfile.NamedTemporaryFile(suffix=".eng", delete=True) as tmp_file:
1969-
tmp_file.write(data_bytes)
1970-
tmp_file.flush()
1981+
data_base64 = dl_data["results"][0].get("data")
1982+
if not data_base64:
1983+
raise ValueError(f"Downloaded .eng data for motor '{name}' is empty or invalid.")
19711984

1972-
motor = GenericMotor.load_from_eng_file(tmp_file.name, **kwargs)
1985+
data_bytes = base64.b64decode(data_base64)
19731986

1974-
return motor
1987+
# Step 3. Create the motor from the .eng file
1988+
tmp_path = None
1989+
try:
1990+
# create a temporary file that persists until we explicitly remove it
1991+
with tempfile.NamedTemporaryFile(suffix=".eng", delete=False) as tmp_file:
1992+
tmp_file.write(data_bytes)
1993+
tmp_file.flush()
1994+
tmp_path = tmp_file.name
1995+
1996+
1997+
motor_instance = GenericMotor.load_from_eng_file(tmp_path, **kwargs)
1998+
return motor_instance
1999+
finally:
2000+
# Ensuring the temporary file is removed
2001+
if tmp_path and path.exists(tmp_path):
2002+
try:
2003+
remove(tmp_path)
2004+
except OSError:
2005+
# If cleanup fails, don't raise: we don't want to mask prior exceptions.
2006+
pass
19752007

19762008
def all_info(self):
19772009
"""Prints out all data and graphs available about the Motor."""

tests/unit/motors/test_genericmotor.py

Lines changed: 67 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import numpy as np
22
import pytest
33
import scipy.integrate
4+
import requests
5+
import base64
6+
47

58
from rocketpy import Function, Motor
69

@@ -212,49 +215,87 @@ def test_load_from_rse_file(generic_motor):
212215
assert thrust_curve[-1][0] == 2.2 # Last point of time
213216
assert thrust_curve[-1][1] == 0.0 # Last thrust point
214217

215-
216-
def test_load_from_thrustcurve_api(generic_motor):
217-
"""Tests the GenericMotor.load_from_thrustcurve_api method.
218-
218+
def test_load_from_thrustcurve_api(monkeypatch, generic_motor):
219+
"""
220+
Tests the GenericMotor.load_from_thrustcurve_api method with mocked ThrustCurve API responses.
219221
Parameters
220222
----------
223+
monkeypatch : pytest.MonkeyPatch
224+
The pytest monkeypatch fixture for mocking.
221225
generic_motor : rocketpy.GenericMotor
222226
The GenericMotor object to be used in the tests.
227+
223228
"""
224-
# using cesaroni data as example
229+
230+
class MockResponse:
231+
def __init__(self, json_data):
232+
self._json_data = json_data
233+
234+
def json(self):
235+
return self._json_data
236+
237+
def raise_for_status(self):
238+
# Simulate a successful HTTP response (200)
239+
return None
240+
241+
# Provide mocked responses for the two endpoints: search.json and download.json
242+
def mock_get(url, params=None):
243+
if "search.json" in url:
244+
# Return a mock search result with a motorId and designation
245+
return MockResponse(
246+
{
247+
"results": [
248+
{
249+
"motorId": "12345",
250+
"designation": "Cesaroni_M1670",
251+
"manufacturer": "Cesaroni",
252+
}
253+
]
254+
}
255+
)
256+
elif "download.json" in url:
257+
# Read the local .eng file and return its base64-encoded content as the API would
258+
eng_path = "data/motors/cesaroni/Cesaroni_M1670.eng"
259+
with open(eng_path, "rb") as f:
260+
encoded = base64.b64encode(f.read()).decode("utf-8")
261+
return MockResponse({"results": [{"data": encoded}]})
262+
else:
263+
raise RuntimeError(f"Unexpected URL called in test mock: {url}")
264+
265+
monkeypatch.setattr(requests, "get", mock_get)
266+
267+
# Expected parameters from the original test
225268
burn_time = (0, 3.9)
226269
dry_mass = 5.231 - 3.101 # 2.130 kg
227270
propellant_initial_mass = 3.101
228271
chamber_radius = 75 / 1000
229272
chamber_height = 757 / 1000
230273
nozzle_radius = chamber_radius * 0.85 # 85% of chamber radius
231274

232-
# Parameters from manual testing using the SolidMotor class as a reference
233275
average_thrust = 1545.218
234276
total_impulse = 6026.350
235277
max_thrust = 2200.0
236278
exhaust_velocity = 1943.357
237279

238-
# creating motor from .eng file
239-
generic_motor = generic_motor.load_from_thrustcurve_api("M1670")
240-
241-
# testing relevant parameters
242-
assert generic_motor.burn_time == burn_time
243-
assert generic_motor.dry_mass == dry_mass
244-
assert generic_motor.propellant_initial_mass == propellant_initial_mass
245-
assert generic_motor.chamber_radius == chamber_radius
246-
assert generic_motor.chamber_height == chamber_height
247-
assert generic_motor.chamber_position == 0
248-
assert generic_motor.average_thrust == pytest.approx(average_thrust)
249-
assert generic_motor.total_impulse == pytest.approx(total_impulse)
250-
assert generic_motor.exhaust_velocity.average(*burn_time) == pytest.approx(
251-
exhaust_velocity
252-
)
253-
assert generic_motor.max_thrust == pytest.approx(max_thrust)
254-
assert generic_motor.nozzle_radius == pytest.approx(nozzle_radius)
255-
256-
# testing thrust curve
280+
# Call the method using the class (works if it's a staticmethod); using type(generic_motor)
281+
# ensures test works if the method is invoked on a GenericMotor instance in the project
282+
motor = type(generic_motor).load_from_thrustcurve_api("M1670")
283+
284+
# Assertions (same as original)
285+
assert motor.burn_time == burn_time
286+
assert motor.dry_mass == dry_mass
287+
assert motor.propellant_initial_mass == propellant_initial_mass
288+
assert motor.chamber_radius == chamber_radius
289+
assert motor.chamber_height == chamber_height
290+
assert motor.chamber_position == 0
291+
assert motor.average_thrust == pytest.approx(average_thrust)
292+
assert motor.total_impulse == pytest.approx(total_impulse)
293+
assert motor.exhaust_velocity.average(*burn_time) == pytest.approx(exhaust_velocity)
294+
assert motor.max_thrust == pytest.approx(max_thrust)
295+
assert motor.nozzle_radius == pytest.approx(nozzle_radius)
296+
297+
# testing thrust curve equality against the local .eng import (as in original test)
257298
_, _, points = Motor.import_eng("data/motors/cesaroni/Cesaroni_M1670.eng")
258-
assert generic_motor.thrust.y_array == pytest.approx(
299+
assert motor.thrust.y_array == pytest.approx(
259300
Function(points, "Time (s)", "Thrust (N)", "linear", "zero").y_array
260301
)

0 commit comments

Comments
 (0)