Skip to content

Commit

Permalink
Merge pull request #40 from iwishiwasaneagle/plotting-and-report
Browse files Browse the repository at this point in the history
  • Loading branch information
iwishiwasaneagle committed Apr 5, 2023
2 parents 43984a3 + 585d2d3 commit 398755e
Show file tree
Hide file tree
Showing 7 changed files with 444 additions and 62 deletions.
1 change: 1 addition & 0 deletions docs/examples/notebook_quick_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
print(f"\tImported gymnasium version {gymnasium.__version__}")
import jdrones
from jdrones.data_models import *
import jdrones.plotting as jplot

print(f"\tImported jdrones version {jdrones.__version__}")

Expand Down
50 changes: 10 additions & 40 deletions docs/examples/simple_position.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
gymnasium==0.27.1
loguru==0.6.0
matplotlib==3.7.1
nptyping==2.5.0
numpy==1.24.2
pandas==2.0.0
pybullet==3.2.5
pydantic==1.10.6
scipy==1.10.1
seaborn==0.12.2
57 changes: 35 additions & 22 deletions src/jdrones/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,40 @@ def iter_to_df(x, *, tag, dt, N, cols):
return df_long


class STATE_ENUM(str, enum.Enum):
X = "x"
Y = "y"
Z = "z"
QX = "qx"
QY = "qy"
QZ = "qz"
QW = "qw"
PHI = "phi"
THETA = "theta"
PSI = "psi"
VX = "vx"
VY = "vy"
VZ = "vz"
P = "p"
Q = "q"
R = "r"
P0 = "P0"
P1 = "P1"
P2 = "P2"
P3 = "P3"

@classmethod
def as_list(cls) -> list[str]:
"""
Convert the enum to a list of strings
Returns
-------
list[str]
"""
return list(map(lambda i: i.value, cls))


class States(np.ndarray):
def __new__(cls, input_array=None):
if input_array is None:
Expand All @@ -208,28 +242,7 @@ def to_df(self, *, tag, dt=1, N=500):
tag=tag,
dt=dt,
N=N,
cols=[
"x",
"y",
"z",
"qx",
"qy",
"qz",
"qw",
"phi",
"theta",
"psi",
"vx",
"vy",
"vz",
"p",
"q",
"r",
"P0",
"P1",
"P2",
"P3",
],
cols=[STATE_ENUM.as_list()],
)
return df

Expand Down
285 changes: 285 additions & 0 deletions src/jdrones/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
import enum
import functools

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from jdrones.data_models import STATE_ENUM


class SUPPORTED_SEABORN_THEMES(str, enum.Enum):
WHITEGRID = "whitegrid"
DARKGRID = "darkgrid"
WHITE = "white"
DARK = "dark"
TICKS = "ticks"


def apply_seaborn_theme(
style: SUPPORTED_SEABORN_THEMES = SUPPORTED_SEABORN_THEMES.WHITEGRID,
):
"""
Apply a seaborn theme
Parameters
----------
style : SUPPORTED_SEABORN_THEMES
Returns
-------
"""
sns.set_theme(style=str(style.value))


def valid_df(df: pd.DataFrame) -> bool:
"""
Ensure the :class:`pandas.DataFrame` is one, or at least in the same shape as one,
created by :meth:`jdrones.data_models.States.to_df`.
Parameters
----------
df : pandas.DataFrame
Returns
-------
bool
True if valid, False if not
"""

is_dataframe = isinstance(df, pd.DataFrame)
if not is_dataframe:
return is_dataframe
has_expected_columns = {"t", "variable", "value", "tag"} == set(df.columns)
if not has_expected_columns:
return has_expected_columns
starts_at_0s = df.t.min() == 0.0
is_sorted_by_t = (df.iloc[0].t == 0.0) & (df.iloc[-1].t == df.t.max())
return starts_at_0s & is_sorted_by_t


def validate_df_wrapper(func):
@functools.wraps(func)
def fn(df, *args, **kwargs):
if not valid_df(df):
raise ValueError("df is invalid")
return func(df, *args, **kwargs)

return fn


def extract_state(df: pd.DataFrame, state: STATE_ENUM) -> pd.DataFrame:
"""
Extract the state from a dataframe
Parameters
----------
df : pandas.DataFrame
state : STATE_ENUM
Returns
-------
pandas.DataFrame
"""
return df[df.variable == state]


def extract_state_value(df: pd.DataFrame, state: STATE_ENUM) -> list[float]:
"""
Extract the state values from a dataframe
Parameters
----------
df : pandas.DataFrame
state : STATE_ENUM
Returns
-------
list[float]
"""
return extract_state(df, state).value


@validate_df_wrapper
def plot_state_vs_state(
df: pd.DataFrame,
state_a: STATE_ENUM,
state_b: STATE_ENUM,
ax: plt.Axes,
label: str = None,
):
"""
Plot the 2d a-b path
Parameters
----------
df : pandas.DataFrame
state_a : STATE_ENUM
state_b : STATE_ENUM
ax : matplotlib.pyplot.Axes
label : str
Optional label
(Default = None)
"""
a = extract_state_value(df, state_a)
b = extract_state_value(df, state_b)
ax.set_xlabel(state_a)
ax.set_ylabel(state_b)
if label is not None:
ax.plot(a, b, label=label)
else:
ax.plot(a, b)


@validate_df_wrapper
def plot_state_vs_state_vs_state(
df: pd.DataFrame,
state_a: STATE_ENUM,
state_b: STATE_ENUM,
state_c: STATE_ENUM,
ax: plt.Axes,
label: str = None,
):
"""
Plot the 3d a-b-c path
Parameters
----------
df : pandas.DataFrame
state_a : STATE_ENUM
state_b : STATE_ENUM
state_c : STATE_ENUM
ax : matplotlib.pyplot.Axes
label : str
Optional label
(Default = None)
"""
if not hasattr(ax, "plot3D"):
raise Exception(
f"{ax=} does not have plot3D. Ensure the correct "
"projection has been set."
)
a = extract_state_value(df, state_a)
b = extract_state_value(df, state_b)
c = extract_state_value(df, state_c)
ax.set_xlabel(state_a)
ax.set_ylabel(state_b)
ax.set_zlabel(state_c)
if label is not None:
ax.plot(a, b, c, label=label)
else:
ax.plot(a, b, c)


@validate_df_wrapper
def plot_2d_path(df: pd.DataFrame, ax: plt.Axes, label: str = None):
"""
Plot the 2d x-y path
Parameters
----------
df : pandas.DataFrame
ax : matplotlib.pyplot.Axes
label : str
Optional label
(Default = None)
"""
plot_state_vs_state(df, "x", "y", ax, label)


@validate_df_wrapper
def plot_3d_path(df: pd.DataFrame, ax: plt.Axes, label: str = None):
"""
Plot the 3d x-y-z path
Parameters
----------
df : pandas.DataFrame
ax : matplotlib.pyplot.Axes
label : str
Optional label
(Default = None)
"""
plot_state_vs_state_vs_state(df, "x", "y", "z", ax, label)
ax.set_xlabel("x (m)")
ax.set_ylabel("y (m)")
ax.set_zlabel("z (m)")


@validate_df_wrapper
def plot_state_over_time(df: pd.DataFrame, variable: STATE_ENUM, ax: plt.Axes):
"""
Plot a state over time
Parameters
----------
df : pandas.DataFrame
variable : STATE_ENUM
The state to plot
ax : matplotlib.pyplot.Axes
"""
a = extract_state(df, variable)
v, t = a.value, a.t
ax.set_ylabel(variable)
ax.set_xlabel("t")
ax.plot(t, v, label=variable)


@validate_df_wrapper
def plot_states_over_time(df: pd.DataFrame, variables: list[STATE_ENUM], ax: plt.Axes):
"""
Plot states over time
Parameters
----------
df : pandas.DataFrame
variables : list[STATE_ENUM]
A list of states to plot
ax : matplotlib.pyplot.Axes
"""
for variable in variables:
plot_state_over_time(df, variable, ax)
ax.set_xlabel("t")


@validate_df_wrapper
def plot_standard(
df: pd.DataFrame, figsize: tuple[float, float] = (12, 12), show: bool = True
):
"""
Plot the standard 2-by-2 layout
.. code::
+------------------+----------------------+
| 3D path | position vs time |
+------------------+----------------------+
| velocity vs time | euler angles vs time |
+------------------+----------------------+
Parameters
----------
df : pandas.DataFrame
figsize: float,float
Figure size
(Default = (12,12))
show : bool
If figure should be shown. Set to :code:`False` if you want to save the
figure using :code:`plt.gcf()`
(Default = :code:`True`)
"""
fig = plt.figure(figsize=figsize)

ax = fig.add_subplot(221, projection="3d")
plot_3d_path(df, ax)

for ind, states, label in (
(222, ("x", "y", "z"), "position (m)"),
(223, ("vx", "vy", "vz"), "velocity (m/s)"),
(224, ("phi", "theta", "psi"), "angular position (rad)"),
):
ax = fig.add_subplot(ind)
plot_states_over_time(df, states, ax)
ax.set_ylabel(label)
ax.legend()

fig.tight_layout()
if show:
plt.show()
Loading

0 comments on commit 398755e

Please sign in to comment.