|
40 | 40 | "AcrobotSwingupSparse": partial(acrobot.Balance, sparse=True),
|
41 | 41 | "BallInCup": ball_in_cup.BallInCup,
|
42 | 42 | "CartpoleBalance": partial(cartpole.Balance, swing_up=False, sparse=False),
|
43 |
| - "CartpoleBalanceSparse": partial(cartpole.Balance, swing_up=False, sparse=True), |
| 43 | + "CartpoleBalanceSparse": partial( |
| 44 | + cartpole.Balance, swing_up=False, sparse=True |
| 45 | + ), |
44 | 46 | "CartpoleSwingup": partial(cartpole.Balance, swing_up=True, sparse=False),
|
45 |
| - "CartpoleSwingupSparse": partial(cartpole.Balance, swing_up=True, sparse=True), |
| 47 | + "CartpoleSwingupSparse": partial( |
| 48 | + cartpole.Balance, swing_up=True, sparse=True |
| 49 | + ), |
46 | 50 | "CheetahRun": cheetah.Run,
|
47 | 51 | "FingerSpin": finger.Spin,
|
48 |
| - "FingerTurnEasy": partial(finger.Turn, target_radius=finger.EASY_TARGET_SIZE), |
49 |
| - "FingerTurnHard": partial(finger.Turn, target_radius=finger.HARD_TARGET_SIZE), |
| 52 | + "FingerTurnEasy": partial( |
| 53 | + finger.Turn, target_radius=finger.EASY_TARGET_SIZE |
| 54 | + ), |
| 55 | + "FingerTurnHard": partial( |
| 56 | + finger.Turn, target_radius=finger.HARD_TARGET_SIZE |
| 57 | + ), |
50 | 58 | "FishSwim": fish.Swim,
|
51 | 59 | "HopperHop": partial(hopper.Hopper, hopping=True),
|
52 | 60 | "HopperStand": partial(hopper.Hopper, hopping=False),
|
|
99 | 107 |
|
100 | 108 |
|
101 | 109 | def __getattr__(name):
|
102 |
| - if name == "ALL_ENVS": |
103 |
| - return tuple(_envs.keys()) |
104 |
| - raise AttributeError(f"module '{__name__}' has no attribute '{name}'") |
| 110 | + if name == "ALL_ENVS": |
| 111 | + return tuple(_envs.keys()) |
| 112 | + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") |
105 | 113 |
|
106 | 114 |
|
107 | 115 | def register_environment(
|
108 | 116 | env_name: str,
|
109 | 117 | env_class: Type[mjx_env.MjxEnv],
|
110 | 118 | cfg_class: Callable[[], config_dict.ConfigDict],
|
111 | 119 | ) -> None:
|
112 |
| - """Register a new environment. |
| 120 | + """Register a new environment. |
113 | 121 |
|
114 |
| - Args: |
115 |
| - env_name: The name of the environment. |
116 |
| - env_class: The environment class. |
117 |
| - cfg_class: The default configuration |
118 |
| - """ |
119 |
| - _envs[env_name] = env_class |
120 |
| - _cfgs[env_name] = cfg_class |
| 122 | + Args: |
| 123 | + env_name: The name of the environment. |
| 124 | + env_class: The environment class. |
| 125 | + cfg_class: The default configuration |
| 126 | + """ |
| 127 | + _envs[env_name] = env_class |
| 128 | + _cfgs[env_name] = cfg_class |
121 | 129 |
|
122 | 130 |
|
123 | 131 | def get_default_config(env_name: str) -> config_dict.ConfigDict:
|
124 |
| - """Get the default configuration for an environment.""" |
125 |
| - if env_name not in _cfgs: |
126 |
| - raise ValueError( |
127 |
| - f"Env '{env_name}' not found in default configs. Available configs:" |
128 |
| - f" {list(_cfgs.keys())}" |
129 |
| - ) |
130 |
| - return _cfgs[env_name]() |
| 132 | + """Get the default configuration for an environment.""" |
| 133 | + if env_name not in _cfgs: |
| 134 | + raise ValueError( |
| 135 | + f"Env '{env_name}' not found in default configs. Available configs:" |
| 136 | + f" {list(_cfgs.keys())}" |
| 137 | + ) |
| 138 | + return _cfgs[env_name]() |
131 | 139 |
|
132 | 140 |
|
133 | 141 | def load(
|
134 | 142 | env_name: str,
|
135 | 143 | config: Optional[config_dict.ConfigDict] = None,
|
136 | 144 | config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
|
137 | 145 | ) -> mjx_env.MjxEnv:
|
138 |
| - """Get an environment instance with the given configuration. |
139 |
| -
|
140 |
| - Args: |
141 |
| - env_name: The name of the environment. |
142 |
| - config: The configuration to use. If not provided, the default |
143 |
| - configuration is used. |
144 |
| - config_overrides: A dictionary of overrides for the configuration. |
145 |
| -
|
146 |
| - Returns: |
147 |
| - An instance of the environment. |
148 |
| - """ |
149 |
| - if env_name not in _envs: |
150 |
| - raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}") |
151 |
| - config = config or get_default_config(env_name) |
152 |
| - return _envs[env_name](config=config, config_overrides=config_overrides) |
| 146 | + """Get an environment instance with the given configuration. |
| 147 | +
|
| 148 | + Args: |
| 149 | + env_name: The name of the environment. |
| 150 | + config: The configuration to use. If not provided, the default |
| 151 | + configuration is used. |
| 152 | + config_overrides: A dictionary of overrides for the configuration. |
| 153 | +
|
| 154 | + Returns: |
| 155 | + An instance of the environment. |
| 156 | + """ |
| 157 | + if env_name not in _envs: |
| 158 | + raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}") |
| 159 | + config = config or get_default_config(env_name) |
| 160 | + return _envs[env_name](config=config, config_overrides=config_overrides) |
0 commit comments