-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathalone_python_dial_mpc_utils.py
More file actions
144 lines (119 loc) · 4.25 KB
/
alone_python_dial_mpc_utils.py
File metadata and controls
144 lines (119 loc) · 4.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from __future__ import annotations
import sys
from pathlib import Path
from typing import Any, Dict
import jax
import jax.numpy as jnp
from brax import math
# ---------------------------------------------------------------------------
# Kinematics helpers
# ---------------------------------------------------------------------------
def global_to_body_velocity(v, q):
    """Re-express a global (world) frame velocity vector in the body frame.

    Args:
        v: Velocity vector expressed in the global frame.
        q: Body orientation quaternion, in whatever convention
            ``brax.math.inv_rotate`` expects (presumably a unit quaternion —
            TODO confirm against the brax version in use).

    Returns:
        ``v`` rotated by the inverse of ``q``.
    """
    body_frame_v = math.inv_rotate(v, q)
    return body_frame_v
def body_to_global_velocity(v, q):
    """Re-express a body-frame velocity vector in the global (world) frame.

    Args:
        v: Velocity vector expressed in the body frame.
        q: Body orientation quaternion, in whatever convention
            ``brax.math.rotate`` expects (presumably a unit quaternion —
            TODO confirm against the brax version in use).

    Returns:
        ``v`` rotated by ``q``.
    """
    world_frame_v = math.rotate(v, q)
    return world_frame_v
@jax.jit
def get_foot_step(duty_ratio, cadence, amplitude, phases, time):
    """Compute the target lift height of every foot at ``time``.

    Every foot follows the same periodic profile, shifted by its own entry of
    ``phases``. Within one cycle the unit profile is ``|cos(x)|``, where ``x``
    is the wrapped, rescaled phase clipped to ``[-pi/2, pi/2]``. For
    ``0 <= duty_ratio < 1`` the profile is non-zero on a ``1 - duty_ratio``
    share of each cycle.

    Args:
        duty_ratio: Value shared by all feet (not mapped by ``vmap``).
            For ``duty_ratio >= 1`` every height is 0.
        cadence: The gait angle advances by ``2*pi*cadence`` per unit of
            ``time`` (presumably Hz if ``time`` is in seconds — confirm).
        amplitude: Multiplier applied to the unit-peak profile (assumed
            scalar; an array would broadcast against the per-foot result).
        phases: Per-foot phase offsets, mapped over axis 0. They are multiplied
            by ``2*pi``, so presumably expressed as fractions of a cycle.
        time: Current time.

    Returns:
        One height per entry of ``phases``: ``amplitude * profile``, where
        profile values below ``1e-6`` are snapped to exactly 0.
    """

    def _foot_height(gait_angle, foot_offset, duty):
        in_swing = duty < 1
        # Phase of this foot relative to the gait clock, wrapped to [-pi, pi).
        rel_angle = jnp.mod(gait_angle + jnp.pi - foot_offset, 2 * jnp.pi) - jnp.pi
        # Stretch the phase so only part of the cycle lands inside
        # [-pi/2, pi/2], the window where cos(.) is non-negative.
        rel_angle = jnp.where(in_swing, rel_angle * 0.5 / (1 - duty), rel_angle)
        bounded = jnp.clip(rel_angle, -jnp.pi / 2, jnp.pi / 2)
        raw = jnp.where(in_swing, jnp.cos(bounded), 0)
        # cos(+-pi/2) is not exactly 0 in floating point; snap tiny values.
        magnitude = jnp.abs(raw)
        return jnp.where(magnitude >= 1e-6, magnitude, 0.0)

    gait_angle = time * 2 * jnp.pi * cadence + jnp.pi
    foot_offsets = 2 * jnp.pi * phases
    heights = jax.vmap(_foot_height, in_axes=(None, 0, None))(
        gait_angle, foot_offsets, duty_ratio
    )
    return amplitude * heights
# ---------------------------------------------------------------------------
# Resource discovery helpers
# ---------------------------------------------------------------------------
def _script_root() -> Path:
main = sys.modules.get("__main__")
if main and getattr(main, "__file__", None):
return Path(main.__file__).resolve().parent
return Path.cwd()
def _candidate_roots() -> list[Path]:
root = _script_root()
candidates = [root]
candidates.extend(root.parents)
extra = set()
extra.add(root / "dial_mpc")
extra.add(root / "dial_mpc-simple")
extra.add(root / "dial_mpc-simple" / "dial_mpc")
extra.add(root / "dial-mpc-simple")
extra.add(root / "dial-mpc-simple" / "dial_mpc")
for parent in root.parents:
extra.add(parent / "dial_mpc")
extra.add(parent / "dial_mpc-simple" / "dial_mpc")
extra.add(parent / "dial-mpc-simple" / "dial_mpc")
extra.add(parent / "dial_mpc-simple")
candidates.extend(extra)
return [c for c in candidates if c]
def get_model_path(robot_name: str, model_name: str) -> Path:
    """Locate a robot model file on disk.

    Probes ``<base>/<subdir>/<robot_name>/<model_name>`` for every base
    directory from ``_candidate_roots()``, in that order. For each base,
    ``subdir`` is tried in the order ``models``, ``dial_mpc/models``,
    ``robots``, ``dial-mpc/models``.

    Args:
        robot_name: Name of the robot's sub-directory.
        model_name: File name of the model inside that sub-directory.

    Returns:
        The first existing candidate, as a resolved path.

    Raises:
        FileNotFoundError: If no candidate exists.
    """
    model_dirs = (
        Path("models"),
        Path("dial_mpc") / "models",
        Path("robots"),
        Path("dial-mpc") / "models",
    )
    candidates = (
        (base / sub / robot_name / model_name).resolve()
        for base in _candidate_roots()
        for sub in model_dirs
    )
    found = next((path for path in candidates if path.exists()), None)
    if found is None:
        raise FileNotFoundError(
            f"Cannot locate model '{model_name}' for robot '{robot_name}'."
        )
    return found
def get_example_path(example_name: str) -> Path:
    """Locate an example YAML configuration file.

    A ``.yaml`` suffix is appended when ``example_name`` does not already end
    with it. The file is then looked up as ``<base>/<subdir>/<example_name>``
    for every base directory from ``_candidate_roots()``. For each base,
    ``subdir`` is tried in the order ``examples``, ``dial_mpc/examples``,
    ``dial-mpc/examples``.

    Args:
        example_name: Example file name, with or without the ``.yaml`` suffix.

    Returns:
        The first existing candidate, as a resolved path.

    Raises:
        FileNotFoundError: If no candidate exists.
    """
    if not example_name.endswith(".yaml"):
        example_name = f"{example_name}.yaml"

    example_dirs = (
        Path("examples"),
        Path("dial_mpc") / "examples",
        Path("dial-mpc") / "examples",
    )
    resolved = (
        (base / sub / example_name).resolve()
        for base in _candidate_roots()
        for sub in example_dirs
    )
    for candidate in resolved:
        if candidate.exists():
            return candidate
    raise FileNotFoundError(
        f"Cannot locate example YAML '{example_name}'."
    )
def load_dataclass_from_dict(
    dataclass_type,
    data_dict: Dict[str, Any],
    convert_list_to_array: bool = False,
):
    """Build a ``dataclass_type`` instance from the matching keys of a dict.

    Only keys that are dataclass fields *and* are accepted by the class's
    ``__init__`` are passed on. Everything else is ignored rather than raising
    ``TypeError``. That includes unknown keys, ``field(init=False)`` fields and
    ``ClassVar`` entries, which all appear in ``__dataclass_fields__``.
    ``InitVar`` pseudo-fields are accepted by ``__init__`` and are passed on.

    Args:
        dataclass_type: A dataclass type (must expose ``__dataclass_fields__``).
        data_dict: Mapping of field names to values.
        convert_list_to_array: If True, every value that is a ``list`` is
            converted with ``jnp.array``. Other values, including tuples and
            dicts, are passed through unchanged.

    Returns:
        A new ``dataclass_type`` instance. Fields absent from ``data_dict``
        keep their defaults.
    """
    import inspect

    init_params = inspect.signature(dataclass_type).parameters
    # A custom __init__ taking **kwargs accepts any field name.
    accepts_any_keyword = any(
        param.kind is inspect.Parameter.VAR_KEYWORD
        for param in init_params.values()
    )
    # Iterate fields in declaration order (deterministic, unlike a set).
    kwargs = {
        name: data_dict[name]
        for name in dataclass_type.__dataclass_fields__
        if name in data_dict and (accepts_any_keyword or name in init_params)
    }
    if convert_list_to_array:
        kwargs = {
            name: jnp.array(value) if isinstance(value, list) else value
            for name, value in kwargs.items()
        }
    return dataclass_type(**kwargs)
# Joint index table mapping "<leg>_<joint>" to a flat index. Legs are ordered
# FR, FL, RR, RL with joints 0, 1, 2 each, so index = 3 * leg_position + joint
# (FR_0 -> 0, ..., RL_2 -> 11); insertion order follows the same sequence.
LegID = {
    f"{leg}_{joint}": 3 * leg_position + joint
    for leg_position, leg in enumerate(("FR", "FL", "RR", "RL"))
    for joint in range(3)
}
# NOTE(review): the constants below look like they mirror the Unitree
# legged_sdk definitions (control-level flags and position/velocity "stop"
# sentinel values); confirm against the SDK before relying on their meaning.
HIGHLEVEL = 0xEE
LOWLEVEL = 0xFF
TRIGERLEVEL = 0xF0  # (sic) name kept as-is for existing importers
PosStopF = 2.146e9
VelStopF = 16000.0