
Merge pull request #135 from iwishiwasaneagle/more-perf
More performance improvements
iwishiwasaneagle committed Aug 7, 2024
2 parents 72c1e76 + 1683157 commit e985ece
Showing 4 changed files with 32 additions and 12 deletions.
5 changes: 4 additions & 1 deletion src/jdrones/data_models.py
@@ -1,13 +1,16 @@
 # Copyright 2023 Jan-Hendrik Ewers
 # SPDX-License-Identifier: GPL-3.0-only
+import contextlib
 import enum
 from typing import Callable
 from typing import Tuple
 
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
-import pybullet as p
+
+with contextlib.redirect_stdout(None):
+    import pybullet as p
 import pydantic
 from jdrones.maths import quat_mul
 from jdrones.transforms import quat_to_euler
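Note: this change (and the matching one in pbdronenev.py below) silences the banner pybullet prints when it is imported, by redirecting sys.stdout for the duration of the import. A minimal, self-contained sketch of the same pattern, with a made-up `chatty_setup` function standing in for the pybullet import:

```python
import contextlib


def chatty_setup():
    # Stand-in for a library that prints a banner while it loads.
    print("engine build info: ...")


with contextlib.redirect_stdout(None):
    chatty_setup()  # sys.stdout is None inside the block, so the banner is dropped

chatty_setup()  # outside the block, output appears as normal
```

One caveat worth knowing: `contextlib.redirect_stdout` only intercepts writes that go through Python's `sys.stdout`; output written directly to the C-level file descriptor is unaffected.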
5 changes: 4 additions & 1 deletion src/jdrones/envs/base/pbdronenev.py
@@ -1,12 +1,12 @@
 # Copyright 2023 Jan-Hendrik Ewers
 # SPDX-License-Identifier: GPL-3.0-only
+import contextlib
 from typing import List
 from typing import Optional
 from typing import Tuple
 from typing import Union
 
 import numpy as np
-import pybullet as p
 import pybullet_data
 from gymnasium.core import ActType
 from gymnasium.core import ObsType
@@ -24,6 +24,9 @@
 from jdrones.types import VEC3
 from jdrones.types import VEC4
 
+with contextlib.redirect_stdout(None):
+    import pybullet as p
+
 
 class PyBulletDroneEnv(BaseDroneEnv):
     """
17 changes: 9 additions & 8 deletions src/jdrones/envs/lqr.py
@@ -1,5 +1,6 @@
 # Copyright 2023 Jan-Hendrik Ewers
 # SPDX-License-Identifier: GPL-3.0-only
+import typing
 from typing import Any
 from typing import Optional
 
@@ -96,15 +97,15 @@ def step(
         setpoint = State.from_x(action)
 
         action = self.controllers["lqr"](measured=self.env.state, setpoint=setpoint)
-        action_with_linearization_assumptions = np.sqrt(
-            np.clip(
-                self.env.model.rpyT2rpm(
-                    [0, 0, 0, self.env.model.mass * self.env.model.g] + action
-                ),
-                0,
-                np.inf,
-            )
+
+        u = typing.cast(
+            np.ndarray,
+            self.env.model.rpyT2rpm(
+                [0, 0, 0, self.env.model.mass * self.env.model.g] + action
+            ),
         )
+        u[u < 0] = 0
+        action_with_linearization_assumptions = np.sqrt(u)
         obs, _, trunc, term, _ = self.env.step(action_with_linearization_assumptions)
 
         return obs, 0, trunc, term, {}
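Note: the lqr.py change replaces np.clip with in-place masking before the square root. Both versions clamp negative rotor terms to zero so np.sqrt never sees a negative input; the rewrite mutates the array directly instead of allocating np.clip's output, and typing.cast has no runtime cost since it simply returns its argument for the benefit of the type checker. A small sketch on made-up numbers (the values and variable names are illustrative, not taken from the library):

```python
import numpy as np

# Hypothetical output of a rotor-mixing step: squared rotor speeds, one of
# which has gone slightly negative due to the linearisation.
rpm_squared = np.array([3.2e5, -1.4e2, 0.0, 2.9e5])

# Previous approach: np.clip allocates a clipped copy before the sqrt.
old = np.sqrt(np.clip(rpm_squared, 0, np.inf))

# New approach: zero the negative entries in place, then take the root.
u = rpm_squared.copy()  # copy only so both variants can be compared here
u[u < 0] = 0
new = np.sqrt(u)

assert np.allclose(old, new)
```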
17 changes: 15 additions & 2 deletions src/jdrones/envs/position.py
@@ -36,12 +36,20 @@ class BasePositionDroneEnv(gymnasium.Env, abc.ABC):
 
     model: URDFModel
 
+    should_calc_reward: bool
+    """
+    Whether to calculate the reward or not. Only useful if the environment is
+    being used directly as the calculation can be expensive.
+    Default: False
+    """
+
     def __init__(
         self,
         model: URDFModel = DronePlus,
         initial_state: State = None,
         dt: float = 1 / 240,
         env: LQRDroneEnv = None,
+        should_calc_reward: bool = False,
     ):
         if env is None:
             env = LQRDroneEnv(model=model, initial_state=initial_state, dt=dt)
@@ -53,6 +61,7 @@ def __init__(
         self.action_space = spaces.Box(
             low=act_bounds[:, 0], high=act_bounds[:, 1], dtype=DType
         )
+        self.should_calc_reward = should_calc_reward
 
     @staticmethod
     def get_reward(states: States) -> float:
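Note: the new should_calc_reward flag makes the potentially expensive trajectory reward opt-in; when it is left at its default, get_reward is skipped and 0 is returned instead (see the step hunk below). A hedged usage sketch; the import path and whether the concrete subclass forwards the keyword are assumptions, not confirmed by this diff:

```python
from jdrones.envs import FifthOrderPolyPositionDroneEnv  # import path assumed

# Opt in to the reward when the environment is driven directly.
env = FifthOrderPolyPositionDroneEnv(should_calc_reward=True)

# Default: reward is always 0, avoiding the cost of get_reward() per step.
fast_env = FifthOrderPolyPositionDroneEnv()
```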
@@ -218,15 +227,19 @@ def step(
             observations.append(obs.copy())
 
             dist = np.linalg.norm(self.env.state.pos - action_as_state.pos)
-            if np.any(np.isnan(dist)):
+            if np.isnan(np.sum(dist)):
                 trunc = True
 
             if dist < 0.5:
                 term = True
                 info["error"] = dist
 
         states = States(observations)
-        return states, self.get_reward(states), term, trunc, info
+        if self.should_calc_reward:
+            reward = self.get_reward(states)
+        else:
+            reward = 0
+        return states, reward, term, trunc, info
 
 
 class FifthOrderPolyPositionDroneEnv(PolynomialPositionBaseDronEnv):
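Note: the other micro-optimisation in this hunk swaps np.any(np.isnan(dist)) for np.isnan(np.sum(dist)), a common NumPy idiom: NaN propagates through addition, so summing first reduces the data to a single scalar and needs only one isnan check, avoiding the np.any reduction (and, for array inputs, the temporary boolean array). A short sketch on made-up data:

```python
import numpy as np

dist = np.array([0.4, 1.7, np.nan, 0.2])  # hypothetical distances

elementwise = np.any(np.isnan(dist))  # builds a boolean array, then reduces it
summed = np.isnan(np.sum(dist))       # one reduction, one scalar check

assert elementwise == summed
```

The trade-off is that a sum can become NaN without any NaN in the input (for example +inf and -inf in the same array), so the summed check can over-report; for a truncation guard that is a safe direction in which to err.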
