Skip to content

Commit b394073

Browse files
committed
Fix bug in privileged
1 parent 5542aaf commit b394073

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

ss2r/benchmark_suites/rccar/rccar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def randomize(rng):
4545
cfg = OmegaConf.to_container(cfg)
4646
in_axes = jax.tree_map(lambda _: 0, sys)
4747
sys, params = randomize(rng)
48-
return sys, in_axes, params[:, None]
48+
return sys, in_axes, params
4949

5050

5151
def rotate_coordinates(state: jnp.array, encode_angle: bool = False) -> jnp.array:

ss2r/common/logging.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,10 @@ def log_video(
118118
fps: int | float = 30,
119119
flush: bool = False,
120120
):
121-
self._writer.add_video(name, np.array(images, copy=False), step, fps=fps)
121+
images = np.array(images, copy=False)
122+
if images.ndim == 4:
123+
images = images[None]
124+
self._writer.add_video(name, images, step, fps=fps)
122125
if flush:
123126
self._writer.flush()
124127

0 commit comments

Comments
 (0)