Skip to content

Commit 8f85580

Browse files
committed
Formatting
1 parent 72f467e commit 8f85580

File tree

1 file changed

+45
-37
lines changed

1 file changed

+45
-37
lines changed

mujoco_playground/_src/dm_control_suite/__init__.py

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,21 @@
4040
"AcrobotSwingupSparse": partial(acrobot.Balance, sparse=True),
4141
"BallInCup": ball_in_cup.BallInCup,
4242
"CartpoleBalance": partial(cartpole.Balance, swing_up=False, sparse=False),
43-
"CartpoleBalanceSparse": partial(cartpole.Balance, swing_up=False, sparse=True),
43+
"CartpoleBalanceSparse": partial(
44+
cartpole.Balance, swing_up=False, sparse=True
45+
),
4446
"CartpoleSwingup": partial(cartpole.Balance, swing_up=True, sparse=False),
45-
"CartpoleSwingupSparse": partial(cartpole.Balance, swing_up=True, sparse=True),
47+
"CartpoleSwingupSparse": partial(
48+
cartpole.Balance, swing_up=True, sparse=True
49+
),
4650
"CheetahRun": cheetah.Run,
4751
"FingerSpin": finger.Spin,
48-
"FingerTurnEasy": partial(finger.Turn, target_radius=finger.EASY_TARGET_SIZE),
49-
"FingerTurnHard": partial(finger.Turn, target_radius=finger.HARD_TARGET_SIZE),
52+
"FingerTurnEasy": partial(
53+
finger.Turn, target_radius=finger.EASY_TARGET_SIZE
54+
),
55+
"FingerTurnHard": partial(
56+
finger.Turn, target_radius=finger.HARD_TARGET_SIZE
57+
),
5058
"FishSwim": fish.Swim,
5159
"HopperHop": partial(hopper.Hopper, hopping=True),
5260
"HopperStand": partial(hopper.Hopper, hopping=False),
@@ -99,54 +107,54 @@
99107

100108

101109
def __getattr__(name):
102-
if name == "ALL_ENVS":
103-
return tuple(_envs.keys())
104-
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
110+
if name == "ALL_ENVS":
111+
return tuple(_envs.keys())
112+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
105113

106114

107115
def register_environment(
108116
env_name: str,
109117
env_class: Type[mjx_env.MjxEnv],
110118
cfg_class: Callable[[], config_dict.ConfigDict],
111119
) -> None:
112-
"""Register a new environment.
120+
"""Register a new environment.
113121
114-
Args:
115-
env_name: The name of the environment.
116-
env_class: The environment class.
117-
cfg_class: The default configuration
118-
"""
119-
_envs[env_name] = env_class
120-
_cfgs[env_name] = cfg_class
122+
Args:
123+
env_name: The name of the environment.
124+
env_class: The environment class.
125+
cfg_class: The default configuration
126+
"""
127+
_envs[env_name] = env_class
128+
_cfgs[env_name] = cfg_class
121129

122130

123131
def get_default_config(env_name: str) -> config_dict.ConfigDict:
124-
"""Get the default configuration for an environment."""
125-
if env_name not in _cfgs:
126-
raise ValueError(
127-
f"Env '{env_name}' not found in default configs. Available configs:"
128-
f" {list(_cfgs.keys())}"
129-
)
130-
return _cfgs[env_name]()
132+
"""Get the default configuration for an environment."""
133+
if env_name not in _cfgs:
134+
raise ValueError(
135+
f"Env '{env_name}' not found in default configs. Available configs:"
136+
f" {list(_cfgs.keys())}"
137+
)
138+
return _cfgs[env_name]()
131139

132140

133141
def load(
134142
env_name: str,
135143
config: Optional[config_dict.ConfigDict] = None,
136144
config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
137145
) -> mjx_env.MjxEnv:
138-
"""Get an environment instance with the given configuration.
139-
140-
Args:
141-
env_name: The name of the environment.
142-
config: The configuration to use. If not provided, the default
143-
configuration is used.
144-
config_overrides: A dictionary of overrides for the configuration.
145-
146-
Returns:
147-
An instance of the environment.
148-
"""
149-
if env_name not in _envs:
150-
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
151-
config = config or get_default_config(env_name)
152-
return _envs[env_name](config=config, config_overrides=config_overrides)
146+
"""Get an environment instance with the given configuration.
147+
148+
Args:
149+
env_name: The name of the environment.
150+
config: The configuration to use. If not provided, the default
151+
configuration is used.
152+
config_overrides: A dictionary of overrides for the configuration.
153+
154+
Returns:
155+
An instance of the environment.
156+
"""
157+
if env_name not in _envs:
158+
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
159+
config = config or get_default_config(env_name)
160+
return _envs[env_name](config=config, config_overrides=config_overrides)

0 commit comments

Comments
 (0)