From 5cc8fc907fcaf08eb1df8abfbbac3ab127ac45b2 Mon Sep 17 00:00:00 2001 From: iwishiwasaneagle Date: Tue, 28 Mar 2023 16:40:30 +0100 Subject: [PATCH 01/12] chore: Add new plotting module in notebook quick setup --- docs/examples/notebook_quick_setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/examples/notebook_quick_setup.py b/docs/examples/notebook_quick_setup.py index 3f9c42d..b5fa190 100644 --- a/docs/examples/notebook_quick_setup.py +++ b/docs/examples/notebook_quick_setup.py @@ -16,6 +16,7 @@ print(f"\tImported gymnasium version {gymnasium.__version__}") import jdrones from jdrones.data_models import * +import jdrones.plotting as jplot print(f"\tImported jdrones version {jdrones.__version__}") From c4e5229f3f98264e6e3b5e3c1c241bcd0d2b0a0f Mon Sep 17 00:00:00 2001 From: iwishiwasaneagle Date: Tue, 28 Mar 2023 16:41:16 +0100 Subject: [PATCH 02/12] feat: Move state labels to an enum to ensure consistency across the codebase --- src/jdrones/data_models.py | 57 +++++++++++++++++++++++--------------- tests/test_data_models.py | 54 ++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 22 deletions(-) diff --git a/src/jdrones/data_models.py b/src/jdrones/data_models.py index f817b28..bbf8ae9 100644 --- a/src/jdrones/data_models.py +++ b/src/jdrones/data_models.py @@ -192,6 +192,40 @@ def iter_to_df(x, *, tag, dt, N, cols): return df_long +class STATE_ENUM(str, enum.Enum): + X = "x" + Y = "y" + Z = "z" + QX = "qx" + QY = "qy" + QZ = "qz" + QW = "qw" + PHI = "phi" + THETA = "theta" + PSI = "psi" + VX = "vx" + VY = "vy" + VZ = "vz" + P = "p" + Q = "q" + R = "r" + P0 = "P0" + P1 = "P1" + P2 = "P2" + P3 = "P3" + + @classmethod + def as_list(cls) -> list[str]: + """ + Convert the enum to a list of strings + + Returns + ------- + list[str] + """ + return list(map(lambda i: i.value, cls)) + + class States(np.ndarray): def __new__(cls, input_array=None): if input_array is None: @@ -208,28 +242,7 @@ def to_df(self, *, tag, dt=1, N=500): tag=tag, dt=dt, N=N, - cols=[ - "x", - "y", - "z", - "qx", - "qy", - "qz", - "qw", - "phi", - "theta", - "psi", - "vx", - "vy", - "vz", - "p", - "q", - "r", - "P0", - "P1", - "P2", - "P3", - ], + cols=[STATE_ENUM.as_list()], ) return df diff --git a/tests/test_data_models.py b/tests/test_data_models.py index d59f1bb..fdeb598 100644 --- a/tests/test_data_models.py +++ b/tests/test_data_models.py @@ -3,6 +3,7 @@ import numpy as np import pytest from jdrones.data_models import State +from jdrones.data_models import STATE_ENUM from scipy.spatial.transform import Rotation as R @@ -278,3 +279,56 @@ def test_state_quat_rotation(quat, act, exp): assert np.allclose(rotated.ang_vel, exp.ang_vel) assert np.allclose(rotated.vel, exp.vel) assert np.allclose(rotated.prop_omega, exp.prop_omega) + + +def test_as_str_list(): + assert tuple(STATE_ENUM.as_list()) == ( + "x", + "y", + "z", + "qx", + "qy", + "qz", + "qw", + "phi", + "theta", + "psi", + "vx", + "vy", + "vz", + "p", + "q", + "r", + "P0", + "P1", + "P2", + "P3", + ) + + +def test_as_str_list_fail(): + assert ( + tuple(STATE_ENUM.as_list()) + != ( + "x", + "y", + "z", + "qx", + "qy", + "qz", + "qw", + "phi", + "theta", + "psi", + "vx", + "vy", + "vz", + "p", + "q", + "r", + "P0", + "P1", + "P2", + "P3", + )[::-1] + ) From 3e0d4e62000b9a3e88b800bed53da15cf26ce3c9 Mon Sep 17 00:00:00 2001 From: iwishiwasaneagle Date: Tue, 28 Mar 2023 14:55:04 +0100 Subject: [PATCH 03/12] refactor: Change gymnasium env names + doctests --- docs/examples/simple_position.ipynb | 42 +++++++++++----------- src/jdrones/__init__.py | 5 +-- src/jdrones/envs/base/lineardronenev.py | 7 ++++ src/jdrones/envs/base/nonlineardronenev.py | 7 ++++ src/jdrones/envs/base/pbdronenev.py | 5 +++ src/jdrones/envs/lqr.py | 7 ++++ src/jdrones/envs/position.py | 11 ++++++ 7 files changed, 61 insertions(+), 23 deletions(-) diff --git a/docs/examples/simple_position.ipynb b/docs/examples/simple_position.ipynb index 1afe4f2..122d07e 100644 --- a/docs/examples/simple_position.ipynb +++ b/docs/examples/simple_position.ipynb @@ -20,8 +20,7 @@ "text": [ "Beginning notebook setup...\n", "\tAdded /home/jhewers/Documents/projects/jdrones/src to path\n", - "\tImported gymnasium version 0.27.1\n", - "\tImported jdrones version unknown\n" + "\tImported gymnasium version 0.27.1\n" ] }, { @@ -37,6 +36,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "\tImported jdrones version unknown\n", "\tImported scipy==1.7.3, numpy==1.23.5, pandas==1.3.5\n", "\tImported functools, collections and itertools\n", "\tImported tqdm (standard and trange)\n", @@ -157,7 +157,7 @@ } ], "source": [ - "pos_env = gymnasium.make(\"PositionDroneEnv-v0\", dt=dt, initial_state=initial_state)" + "first_order_env = gymnasium.make(\"FirstOrderPolyPositionDroneEnv-v0\", dt=dt, initial_state=initial_state)" ] }, { @@ -169,7 +169,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "19df477d36f445ec8d639b6ad5f19c35", + "model_id": "8c8fe78d6c5442fbb2c490388084e7e9", "version_major": 2, "version_minor": 0 }, @@ -182,7 +182,7 @@ } ], "source": [ - "pos_obs = run_sim(pos_env, wps)" + "first_order_obs = run_sim(first_order_env, wps)" ] }, { @@ -1159,7 +1159,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -1170,7 +1170,7 @@ } ], "source": [ - "plot_states(pos_obs)" + "plot_states(first_order_obs)" ] }, { @@ -1199,7 +1199,7 @@ } ], "source": [ - "poly_env = gymnasium.make(\"PolyPositionDroneEnv-v0\", dt=dt, initial_state=initial_state)" + "fifth_order_env = gymnasium.make(\"FifthOrderPolyPositionDroneEnv-v0\", dt=dt, initial_state=initial_state)" ] }, { @@ -1211,7 +1211,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "7d5e9f27998d44db9ef31999d508ab6c", + "model_id": "b03299de408447d38b0de6c37bb42b88", "version_major": 2, "version_minor": 0 }, @@ -1224,7 +1224,7 @@ } ], "source": [ - "poly_obs = run_sim(poly_env, wps)" + "fifth_order_obs = run_sim(fifth_order_env, wps)" ] }, { @@ -2201,7 +2201,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -2212,7 +2212,7 @@ } ], "source": [ - "plot_states(poly_obs)" + "plot_states(fifth_order_obs)" ] }, { @@ -3197,7 +3197,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -3213,19 +3213,19 @@ "fig = plt.figure()\n", "ax = fig.add_subplot(111, projection=\"3d\")\n", "\n", - "poly_df = States(np.concatenate(poly_obs)).to_df(tag='LQR+5O-Poly', dt=dt).sort_values(\"t\")\n", + "first_order_df = States(np.concatenate(first_order_obs)).to_df(tag='LQR+5O-Poly', dt=dt).sort_values(\"t\")\n", "x, y, z = (\n", - " poly_df[poly_df.variable == \"x\"],\n", - " poly_df[poly_df.variable == \"y\"],\n", - " poly_df[poly_df.variable == \"z\"],\n", + " first_order_df[first_order_df.variable == \"x\"],\n", + " first_order_df[first_order_df.variable == \"y\"],\n", + " first_order_df[first_order_df.variable == \"z\"],\n", ")\n", "ax.plot(x.value, y.value, z.value, label=\"LQR+5O-Poly\")\n", "\n", - "pos_df = States(np.concatenate(pos_obs)).to_df(tag='LQR+1O-Poly', dt=dt).sort_values(\"t\")\n", + "fifth_order_df = States(np.concatenate(fifth_order_obs)).to_df(tag='LQR+1O-Poly', dt=dt).sort_values(\"t\")\n", "x, y, z = (\n", - " pos_df[pos_df.variable == \"x\"],\n", - " pos_df[pos_df.variable == \"y\"],\n", - " pos_df[pos_df.variable == \"z\"],\n", + " fifth_order_df[fifth_order_df.variable == \"x\"],\n", + " fifth_order_df[fifth_order_df.variable == \"y\"],\n", + " fifth_order_df[fifth_order_df.variable == \"z\"],\n", ")\n", "ax.plot(x.value, y.value, z.value, label=\"LQR+1O-Poly\")\n", "\n", diff --git a/src/jdrones/__init__.py b/src/jdrones/__init__.py index b23275c..2f7b8e7 100644 --- a/src/jdrones/__init__.py +++ b/src/jdrones/__init__.py @@ -17,12 +17,13 @@ entry_point="jdrones.envs:LQRDroneEnv", ) register( - "PositionDroneEnv-v0", + "FirstOrderPolyPositionDroneEnv-v0", entry_point="jdrones.envs:FirstOrderPolyPositionDroneEnv", ) register( - "PolyPositionDroneEnv-v0", + "FifthOrderPolyPositionDroneEnv-v0", entry_point="jdrones.envs:FifthOrderPolyPositionDroneEnv", ) + __version__ = "unknown" diff --git a/src/jdrones/envs/base/lineardronenev.py b/src/jdrones/envs/base/lineardronenev.py index 565ec8f..5bc0973 100644 --- a/src/jdrones/envs/base/lineardronenev.py +++ b/src/jdrones/envs/base/lineardronenev.py @@ -14,6 +14,13 @@ class LinearDynamicModelDroneEnv(BaseDroneEnv): + """ + >>> import jdrones + >>> import gymnasium + >>> gymnasium.make("LinearDynamicModelDroneEnv-v0") + >>> + """ + def __init__( self, model: URDFModel = DronePlus, diff --git a/src/jdrones/envs/base/nonlineardronenev.py b/src/jdrones/envs/base/nonlineardronenev.py index 704e4eb..dccfe52 100644 --- a/src/jdrones/envs/base/nonlineardronenev.py +++ b/src/jdrones/envs/base/nonlineardronenev.py @@ -12,6 +12,13 @@ class NonlinearDynamicModelDroneEnv(BaseDroneEnv): + """ + >>> import jdrones + >>> import gymnasium + >>> gymnasium.make("NonLinearDynamicModelDroneEnv-v0") + >>> + """ + @staticmethod def calc_dstate(action: PropellerAction, state: State, model: URDFModel): Inertias = np.diag(model.I) diff --git a/src/jdrones/envs/base/pbdronenev.py b/src/jdrones/envs/base/pbdronenev.py index 283e7a9..2ce3b03 100644 --- a/src/jdrones/envs/base/pbdronenev.py +++ b/src/jdrones/envs/base/pbdronenev.py @@ -29,6 +29,11 @@ class PyBulletDroneEnv(BaseDroneEnv): """ Base drone environment. Handles pybullet loading, and application of forces. Generalizes the physics to allow other models to be used. + + >>> import jdrones + >>> import gymnasium + >>> gymnasium.make("PyBulletDroneEnv-v0") + >>> """ ids: PyBulletIds diff --git a/src/jdrones/envs/lqr.py b/src/jdrones/envs/lqr.py index f3cd9a3..a21d7a6 100644 --- a/src/jdrones/envs/lqr.py +++ b/src/jdrones/envs/lqr.py @@ -17,6 +17,13 @@ class LQRDroneEnv(BaseControlledEnv): + """ + >>> import jdrones + >>> import gymnasium + >>> gymnasium.make("LQRDroneEnv-v0") + >>> + """ + def __init__( self, model: URDFModel = DronePlus, diff --git a/src/jdrones/envs/position.py b/src/jdrones/envs/position.py index 084233d..733e3c8 100644 --- a/src/jdrones/envs/position.py +++ b/src/jdrones/envs/position.py @@ -201,6 +201,12 @@ class FifthOrderPolyPositionDroneEnv(PolynomialPositionBaseDronEnv): If the time taken exceeds :math:`T`, the original target position is given as a raw input. However, if this were to happen, the distance is small enough to ensure stability. + + >>> import jdrones + >>> import gymnasium + >>> gymnasium.make("FifthOrderPolyPositionDroneEnv-v0") + >>> + """ @staticmethod @@ -253,6 +259,11 @@ class FirstOrderPolyPositionDroneEnv(PolynomialPositionBaseDronEnv): If the time taken exceeds :math:`T`, the original target position is given as a raw input. However, if this were to happen, the distance is small enough to ensure stability. + + >>> import jdrones + >>> import gymnasium + >>> gymnasium.make("FirstOrderPolyPositionDroneEnv-v0") + >>> """ @staticmethod From 09585ab624ada6abe6fb88142707ff8eafab6533 Mon Sep 17 00:00:00 2001 From: iwishiwasaneagle Date: Tue, 28 Mar 2023 14:56:21 +0100 Subject: [PATCH 04/12] fix: White space in doctest for PID --- src/jdrones/controllers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/jdrones/controllers.py b/src/jdrones/controllers.py index 28edb6a..c13997e 100644 --- a/src/jdrones/controllers.py +++ b/src/jdrones/controllers.py @@ -109,9 +109,7 @@ class PID(AngleController): >>> pid = PID(1,2,3,dt=0.1) >>> pid(measured=0,setpoint=1) - 31.2 - - + 31.2 """ Kp: float From 3a1da91885d5f8e16bba000ecfe362e5b4417d80 Mon Sep 17 00:00:00 2001 From: iwishiwasaneagle Date: Tue, 28 Mar 2023 15:05:26 +0100 Subject: [PATCH 05/12] docs: Add gymnasium env names to README --- README.md | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index fab8c8d..6fe97ac 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,23 @@ Provide a [gymnasium] style interface using a physics simulation engine ([pybull > :warning: This code is still in alpha and **will** change over time as I use it :warning: +## Environments + +The environment documentation can be found [here](https://jdrones.janhendrikewers.uk/envs.html) + +### Base Dynamics + +1. `PyBulletDroneEnv-v0` [:link:](https://jdrones.janhendrikewers.uk/envs.html#pybulletdroneenv) +2. `NonLinearDynamicModelDroneEnv-v0` [:link:](https://jdrones.janhendrikewers.uk/envs.html#nonlineardynamicmodeldroneenv) +3. `LinearDynamicModelDroneEnv-v0` [:link:](https://jdrones.janhendrikewers.uk/envs.html#lineardynamicmodeldroneenv) + +### Attitude +1. `LQRDroneEnv-v0` [:link:](https://jdrones.janhendrikewers.uk/envs.html#lqrdroneenv) + +### Position +1. `FirstOrderPolyPositionDroneEnv-v0` [:link:](https://jdrones.janhendrikewers.uk/envs.html#firstorderpolypositiondroneenv) +2`FifthOrderPolyPositionDroneEnv-v0` [:link:](https://jdrones.janhendrikewers.uk/envs.html#fifthorderpolypositiondroneenv) + ## Development Create the local development environment: @@ -38,10 +55,12 @@ PYTHONPATH=$GIT_DIR/src python -m pytest -s -q -n auto --only-slow-integration $ - [ ] Better sensor modelling and kalman filters - [ ] Performance improvements of simulation using either compiled code or a JIT -- [ ] Better controllers -- [ ] Better trajectory generation between waypoints +- [x] Better controllers + - LQR +- [x] Better trajectory generation between waypoints + - First- and fifth-order polynomial trajectory generation - [x] Examples -- [ ] Proper integration testing +- [x] Proper integration testing - [ ] Higher fidelity motor models [gymnasium]: https://gymnasium.farama.org/ From 063c55ef35a80a9b4f4d05fe972dda36f1a1715a Mon Sep 17 00:00:00 2001 From: iwishiwasaneagle Date: Tue, 28 Mar 2023 14:56:06 +0100 Subject: [PATCH 06/12] test: Add proper doctests into files --- .github/workflows/CI.yml | 43 ++++++++++------------------------------ 1 file changed, 10 insertions(+), 33 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index aa5ba38..21a78e4 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -62,11 +62,10 @@ jobs: - name: Install dependencies if: steps.cache-virtualenv.outputs.cache-hit != 'true' run: | - pip install wheel -r requirements.txt -r tests/requirements.txt + pip install . -r requirements.txt -r tests/requirements.txt - name: Run tests run: | - PYTHONPATH=$PWD/src \ pytest tests \ --cov-report=xml \ --cov-branch \ @@ -98,7 +97,7 @@ jobs: run: pip wheel . --no-deps run-doc-tests: - name: Build docs and test for issues + name: Run doctests runs-on: ubuntu-latest steps: @@ -116,36 +115,14 @@ jobs: - name: Install dependencies run: | - sudo apt-get update && sudo apt-get -y install pandoc - python -m pip install -r docs/requirements.txt - - - name: Build docs and analyze output - run: | - PYTHONPATH=$PWD/src \ - python -m \ - sphinx \ - -b html \ - -d docs/_build/html/doctrees \ - docs \ - docs/_build/html/ \ - >build.log - - IGNORE_PATTERN=$(cat docs/sphinx_warning_ignore.txt | \ - sed -e 's/#.*$//g'| \ - sed -re '/^\s*$/d' | \ - awk '{print}' ORS='\\|' | \ - sed -e 's/\\|$//g' \ - ) - - OUTPUT=$(grep -E "[\w\s]+\.py" build.log | grep -v "$IGNORE_PATTERN") - - if [ -z "$OUTPUT" ]; then - echo $OUTPUT - exit 1 - else - echo "Everything is a-okay!" - exit 0 - fi + pip install . -r requirements.txt -r tests/requirements.txt + + - name: Run doctests + run : | + python -m doctest $(\ + find src/jdrones -iname "*.py" -not -iname "__main__.py" | \ + tr '\n' ' ' \ + ) docstr-cov: runs-on: ubuntu-latest From 1da699845553988e209ca3f451caff36e092a11a Mon Sep 17 00:00:00 2001 From: iwishiwasaneagle Date: Tue, 28 Mar 2023 16:42:00 +0100 Subject: [PATCH 07/12] feat: Plotting utility functions --- src/jdrones/plotting.py | 273 ++++++++++++++++++++++++++++++++++++++++ tests/test_plotting.py | 57 +++++++++ 2 files changed, 330 insertions(+) create mode 100644 src/jdrones/plotting.py create mode 100644 tests/test_plotting.py diff --git a/src/jdrones/plotting.py b/src/jdrones/plotting.py new file mode 100644 index 0000000..f93f142 --- /dev/null +++ b/src/jdrones/plotting.py @@ -0,0 +1,273 @@ +import enum +import functools + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from jdrones.data_models import STATE_ENUM + + +class SUPPORTED_SEABORN_THEMES(str, enum.Enum): + WHITEGRID = "whitegrid" + DARKGRID = "darkgrid" + WHITE = "white" + DARK = "dark" + TICKS = "ticks" + + +def apply_seaborn_theme( + style: SUPPORTED_SEABORN_THEMES = SUPPORTED_SEABORN_THEMES.WHITEGRID, +): + """ + Apply a seaborn theme + Parameters + ---------- + style : SUPPORTED_SEABORN_THEMES + + Returns + ------- + """ + sns.set_theme(style=str(style.value)) + + +def valid_df(df: pd.DataFrame) -> bool: + """ + Ensure the :class:`pandas.DataFrame` is one, or at least in the same shape as one, + created by :meth:`jdrones.data_models.States.to_df`. + + Parameters + ---------- + df : pandas.DataFrame + + Returns + ------- + bool + True if valid, False if not + """ + + is_dataframe = isinstance(df, pd.DataFrame) + if not is_dataframe: + return is_dataframe + has_expected_columns = {"t", "variable", "value", "tag"} == set(df.columns) + if not has_expected_columns: + return has_expected_columns + starts_at_0s = df.t.min() == 0.0 + is_sorted_by_t = (df.iloc[0].t == 0.0) & (df.iloc[-1].t == df.t.max()) + return starts_at_0s & is_sorted_by_t + + +def validate_df_wrapper(func): + @functools.wraps(func) + def fn(df, *args, **kwargs): + if not valid_df(df): + raise ValueError("df is invalid") + return func(df, *args, **kwargs) + + return fn + + +def extract_state(df: pd.DataFrame, state: STATE_ENUM) -> pd.DataFrame: + """ + Extract the state from a dataframe + + Parameters + ---------- + df : pandas.DataFrame + state : STATE_ENUM + + Returns + ------- + pandas.DataFrame + """ + return df[df.variable == state] + + +def extract_state_value(df: pd.DataFrame, state: STATE_ENUM) -> list[float]: + """ + Extract the state values from a dataframe + + Parameters + ---------- + df : pandas.DataFrame + state : STATE_ENUM + + Returns + ------- + list[float] + """ + return extract_state(df, state).value + + +@validate_df_wrapper +def plot_state_vs_state( + df: pd.DataFrame, + state_a: STATE_ENUM, + state_b: STATE_ENUM, + ax: plt.Axes, + label: str = None, +): + """ + Plot the 2d a-b path + + Parameters + ---------- + df : pandas.DataFrame + state_a : STATE_ENUM + state_b : STATE_ENUM + ax : matplotlib.pyplot.Axes + label : str + Optional label + (Default = None) + """ + a = extract_state_value(df, state_a) + b = extract_state_value(df, state_b) + ax.set_xlabel(state_a) + ax.set_ylabel(state_b) + if label is not None: + ax.plot(a, b, label=label) + else: + ax.plot(a, b) + + +@validate_df_wrapper +def plot_state_vs_state_vs_state( + df: pd.DataFrame, + state_a: STATE_ENUM, + state_b: STATE_ENUM, + state_c: STATE_ENUM, + ax: plt.Axes, + label: str = None, +): + """ + Plot the 3d a-b-c path + + Parameters + ---------- + df : pandas.DataFrame + state_a : STATE_ENUM + state_b : STATE_ENUM + state_c : STATE_ENUM + ax : matplotlib.pyplot.Axes + label : str + Optional label + (Default = None) + """ + if not hasattr(ax, "plot3D"): + raise Exception( + f"{ax=} does not have plot3D. Ensure the correct " + "projection has been set." + ) + a = extract_state_value(df, state_a) + b = extract_state_value(df, state_b) + c = extract_state_value(df, state_c) + ax.set_xlabel(state_a) + ax.set_ylabel(state_b) + ax.set_zlabel(state_c) + if label is not None: + ax.plot(a, b, c, label=label) + else: + ax.plot(a, b, c) + + +@validate_df_wrapper +def plot_2d_path(df: pd.DataFrame, ax: plt.Axes, label: str = None): + """ + Plot the 2d x-y path + + Parameters + ---------- + df : pandas.DataFrame + ax : matplotlib.pyplot.Axes + label : str + Optional label + (Default = None) + """ + plot_state_vs_state(df, "x", "y", ax, label) + + +@validate_df_wrapper +def plot_3d_path(df: pd.DataFrame, ax: plt.Axes, label: str = None): + """ + Plot the 3d x-y-z path + + Parameters + ---------- + df : pandas.DataFrame + ax : matplotlib.pyplot.Axes + label : str + Optional label + (Default = None) + """ + plot_state_vs_state_vs_state(df, "x", "y", "z", ax, label) + + +@validate_df_wrapper +def plot_state_over_time(df: pd.DataFrame, variable: STATE_ENUM, ax: plt.Axes): + """ + Plot a state over time + + Parameters + ---------- + df : pandas.DataFrame + variable : STATE_ENUM + The state to plot + ax : matplotlib.pyplot.Axes + """ + a = extract_state(df, variable) + v, t = a.value, a.t + ax.set_ylabel(variable) + ax.set_xlabel("t") + ax.plot(t, v, label=variable) + + +@validate_df_wrapper +def plot_states_over_time(df: pd.DataFrame, variables: list[STATE_ENUM], ax: plt.Axes): + """ + Plot states over time + + Parameters + ---------- + df : pandas.DataFrame + variables : list[STATE_ENUM] + A list of states to plot + ax : matplotlib.pyplot.Axes + """ + for variable in variables: + plot_state_over_time(df, variable, ax) + ax.set_xlabel("t") + + +@validate_df_wrapper +def plot_standard(df: pd.DataFrame, figsize: tuple[float, float] = (12, 12)): + """ + Plot the standard 2-by-2 layout + + .. code:: + +------------------+----------------------+ + | 3D path | position vs time | + +------------------+----------------------+ + | velocity vs time | euler angles vs time | + +------------------+----------------------+ + + Parameters + ---------- + df : pandas.DataFrame + figsize: float,float + Figure size + (Default = (12,12)) + """ + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(221, projection="3d") + plot_3d_path(df, ax) + + for ind, states in ( + (222, ("x", "y", "z")), + (223, ("vx", "vy", "vz")), + (224, ("phi", "theta", "psi")), + ): + ax = fig.add_subplot(ind) + plot_states_over_time(df, states, ax) + ax.legend() + + fig.tight_layout() + plt.show() diff --git a/tests/test_plotting.py b/tests/test_plotting.py new file mode 100644 index 0000000..77f1cc1 --- /dev/null +++ b/tests/test_plotting.py @@ -0,0 +1,57 @@ +import numpy as np +import pytest +from jdrones.data_models import STATE_ENUM +from jdrones.data_models import States +from jdrones.plotting import extract_state +from jdrones.plotting import extract_state_value +from jdrones.plotting import valid_df + + +@pytest.fixture(params=[(5, 20), (100, 20)]) +def states(request): + return States(np.random.rand(*request.param)) + + +@pytest.fixture(params=["test"]) +def tag(request): + return request.param + + +@pytest.fixture +def dataframe(states, tag, dt): + return states.to_df(tag=tag, dt=dt) + + +def test_valid_df(dataframe): + assert valid_df(dataframe) + + +def test_invalid_df_by_sort(dataframe): + assert not valid_df(dataframe.iloc[::-1]) + + +def test_invalid_df_by_type(dataframe): + assert not valid_df(int(1)) + + +@pytest.mark.parametrize("col", ["t", "variable", "value", "tag"]) +def test_invalid_df_by_column(dataframe, col): + assert not valid_df(dataframe.drop(columns=[col])) + + +ALL_STATES = pytest.mark.parametrize( + "i,variable", + tuple(enumerate(STATE_ENUM.as_list())), +) + + +@ALL_STATES +def test_extract_variable(dataframe, states, i, variable): + df = extract_state(dataframe, variable) + assert np.allclose(df.value, states[:, i]) + + +@ALL_STATES +def test_extract_variable_value(dataframe, states, i, variable): + value = extract_state_value(dataframe, variable) + assert np.allclose(value, states[:, i]) From ec60216988de9623348777f9874815346317175c Mon Sep 17 00:00:00 2001 From: iwishiwasaneagle Date: Tue, 28 Mar 2023 16:58:26 +0100 Subject: [PATCH 08/12] chore: Use new plotting utility functions --- docs/examples/simple_position.ipynb | 50 ++++++----------------------- 1 file changed, 10 insertions(+), 40 deletions(-) diff --git a/docs/examples/simple_position.ipynb b/docs/examples/simple_position.ipynb index 122d07e..8b0b35b 100644 --- a/docs/examples/simple_position.ipynb +++ b/docs/examples/simple_position.ipynb @@ -103,24 +103,7 @@ "source": [ "def plot_states(states):\n", " df = States(np.concatenate(states)).to_df(tag='Observations',dt=dt)\n", - " fig, ax = plt.subplots(2,2,figsize=(10,8))\n", - " ax = ax.flatten()\n", - "\n", - " sns.lineplot(data=df.query(\"variable in ('x','y','z')\"), x='t',y='value',hue='variable', style='tag',ax=ax[0])\n", - " ax[0].legend()\n", - "\n", - " sns.lineplot(data=df.query(\"variable in ('phi','theta','psi')\"), x='t',y='value',hue='variable', style='tag',ax=ax[1])\n", - " ax[1].legend()\n", - "\n", - " sns.lineplot(data=df.query(\"variable in ('vx','vy','vz')\"), x='t',y='value',hue='variable', style='tag',ax=ax[2])\n", - " ax[2].legend()\n", - "\n", - " sns.lineplot(data=df.query(\"variable in ('P0','P1','P2','P3')\"), x='t',y='value',hue='variable', style='tag',ax=ax[3])\n", - " ax[3].legend()\n", - "\n", - " fig.tight_layout()\n", - " \n", - " plt.show()" + " jplot.plot_standard(df)" ] }, { @@ -169,7 +152,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "8c8fe78d6c5442fbb2c490388084e7e9", + "model_id": "95cb598457e64207b45fa3600ee69ff1", "version_major": 2, "version_minor": 0 }, @@ -1159,7 +1142,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -1211,7 +1194,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b03299de408447d38b0de6c37bb42b88", + "model_id": "c2edb276e07843f29f3840976a3ac5a2", "version_major": 2, "version_minor": 0 }, @@ -2201,7 +2184,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -3197,7 +3180,7 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ "" @@ -3213,24 +3196,11 @@ "fig = plt.figure()\n", "ax = fig.add_subplot(111, projection=\"3d\")\n", "\n", - "first_order_df = States(np.concatenate(first_order_obs)).to_df(tag='LQR+5O-Poly', dt=dt).sort_values(\"t\")\n", - "x, y, z = (\n", - " first_order_df[first_order_df.variable == \"x\"],\n", - " first_order_df[first_order_df.variable == \"y\"],\n", - " first_order_df[first_order_df.variable == \"z\"],\n", - ")\n", - "ax.plot(x.value, y.value, z.value, label=\"LQR+5O-Poly\")\n", - "\n", - "fifth_order_df = States(np.concatenate(fifth_order_obs)).to_df(tag='LQR+1O-Poly', dt=dt).sort_values(\"t\")\n", - "x, y, z = (\n", - " fifth_order_df[fifth_order_df.variable == \"x\"],\n", - " fifth_order_df[fifth_order_df.variable == \"y\"],\n", - " fifth_order_df[fifth_order_df.variable == \"z\"],\n", - ")\n", - "ax.plot(x.value, y.value, z.value, label=\"LQR+1O-Poly\")\n", + "first_order_df = States(np.concatenate(first_order_obs)).to_df(tag='LQR+5O-Poly', dt=dt)\n", + "fifth_order_df = States(np.concatenate(fifth_order_obs)).to_df(tag='LQR+1O-Poly', dt=dt)\n", "\n", - "wps = np.array(wps)\n", - "ax.scatter(wps.T[0], wps.T[1], wps.T[2],marker='x', label=\"WP\")\n", + "jplot.plot_3d_path(fifth_order_df, ax, \"LQR+50-Poly\")\n", + "jplot.plot_3d_path(first_order_df, ax, \"LQR+10-Poly\")\n", "\n", "ax.legend()\n", "\n", From 72f99ae1903af9c04917dc58a3df16bdac0efec2 Mon Sep 17 00:00:00 2001 From: iwishiwasaneagle Date: Tue, 28 Mar 2023 17:26:51 +0100 Subject: [PATCH 09/12] chore: Add matplotlib and seaborn as a requirement --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 6c366e5..bc231b5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,10 @@ gymnasium==0.27.1 loguru==0.6.0 +matplotlib==3.7.1 nptyping==2.5.0 numpy==1.24.2 pandas==1.5.3 pybullet==3.2.5 pydantic==1.10.6 scipy==1.10.1 +seaborn==0.12.2 From 8c004df83fbce2069e09f0a15420ffbac5e345df Mon Sep 17 00:00:00 2001 From: iwishiwasaneagle Date: Tue, 4 Apr 2023 10:59:47 +0100 Subject: [PATCH 10/12] fix: Option to not show, to enable saving of figure --- src/jdrones/plotting.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/jdrones/plotting.py b/src/jdrones/plotting.py index f93f142..d156aea 100644 --- a/src/jdrones/plotting.py +++ b/src/jdrones/plotting.py @@ -238,7 +238,9 @@ def plot_states_over_time(df: pd.DataFrame, variables: list[STATE_ENUM], ax: plt @validate_df_wrapper -def plot_standard(df: pd.DataFrame, figsize: tuple[float, float] = (12, 12)): +def plot_standard( + df: pd.DataFrame, figsize: tuple[float, float] = (12, 12), show: bool = True +): """ Plot the standard 2-by-2 layout @@ -255,6 +257,10 @@ def plot_standard(df: pd.DataFrame, figsize: tuple[float, float] = (12, 12)): figsize: float,float Figure size (Default = (12,12)) + show : bool + If figure should be shown. Set to :code:`False` if you want to save the + figure using :code:`plt.gcf()` + (Default = :code:`True`) """ fig = plt.figure(figsize=figsize) ax = fig.add_subplot(221, projection="3d") @@ -270,4 +276,5 @@ def plot_standard(df: pd.DataFrame, figsize: tuple[float, float] = (12, 12)): ax.legend() fig.tight_layout() - plt.show() + if show: + plt.show() From f6c03aa04de9f4f8eb1461fe5ca7dabfa2f2cc8f Mon Sep 17 00:00:00 2001 From: iwishiwasaneagle Date: Wed, 5 Apr 2023 15:20:10 +0100 Subject: [PATCH 11/12] chore: Proper axis labels for 3D path plot --- src/jdrones/plotting.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/jdrones/plotting.py b/src/jdrones/plotting.py index d156aea..cf2c3e8 100644 --- a/src/jdrones/plotting.py +++ b/src/jdrones/plotting.py @@ -199,6 +199,9 @@ def plot_3d_path(df: pd.DataFrame, ax: plt.Axes, label: str = None): (Default = None) """ plot_state_vs_state_vs_state(df, "x", "y", "z", ax, label) + ax.set_xlabel("x (m)") + ax.set_ylabel("y (m)") + ax.set_zlabel("z (m)") @validate_df_wrapper From 585d2d3cc535e9a5b9394f16f79589b88525d4bf Mon Sep 17 00:00:00 2001 From: iwishiwasaneagle Date: Wed, 5 Apr 2023 15:20:43 +0100 Subject: [PATCH 12/12] chore: Proper axis labels for standard plots --- src/jdrones/plotting.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/jdrones/plotting.py b/src/jdrones/plotting.py index cf2c3e8..a9cc92f 100644 --- a/src/jdrones/plotting.py +++ b/src/jdrones/plotting.py @@ -266,16 +266,18 @@ def plot_standard( (Default = :code:`True`) """ fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(221, projection="3d") plot_3d_path(df, ax) - for ind, states in ( - (222, ("x", "y", "z")), - (223, ("vx", "vy", "vz")), - (224, ("phi", "theta", "psi")), + for ind, states, label in ( + (222, ("x", "y", "z"), "position (m)"), + (223, ("vx", "vy", "vz"), "velocity (m/s)"), + (224, ("phi", "theta", "psi"), "angular position (rad)"), ): ax = fig.add_subplot(ind) plot_states_over_time(df, states, ax) + ax.set_ylabel(label) ax.legend() fig.tight_layout()