Skip to content

Commit

Permalink
feat: switch to ruff and upgrade pre-commit hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
sash-a committed Nov 4, 2024
1 parent 66dfc93 commit 0acf557
Show file tree
Hide file tree
Showing 196 changed files with 902 additions and 2,244 deletions.
43 changes: 14 additions & 29 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
default_stages: [ "commit", "commit-msg", "push" ]
default_stages: [ "pre-commit", "commit-msg", "pre-push" ]
default_language_version:
python: python3


repos:
- repo: https://github.com/timothycrosley/isort
rev: 5.11.5
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.7.2
hooks:
- id: isort

- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
name: "Code formatter"
# Run the linter.
- id: ruff
types_or: [ python ]
args: [ --fix ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi, jupyter ]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
rev: v5.0.0
hooks:
- id: end-of-file-fixer
name: "End of file fixer"
Expand All @@ -32,22 +33,6 @@ repos:
- id: trailing-whitespace
name: "Trailing whitespace fixer"

- repo: https://github.com/PyCQA/flake8
rev: 7.1.1
hooks:
- id: flake8
name: "Linter"
args:
- --config=setup.cfg
additional_dependencies:
- pep8-naming
- flake8-builtins
- flake8-comprehensions
- flake8-bugbear
- flake8-pytest-style
- flake8-cognitive-complexity
- importlib-metadata<5.0

- repo: local
hooks:
- id: mypy
Expand All @@ -58,15 +43,15 @@ repos:
pass_filenames: false

- repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook
rev: v4.1.0
rev: v9.18.0
hooks:
- id: commitlint
name: "Commit linter"
stages: [ commit-msg ]
additional_dependencies: [ '@commitlint/config-conventional' ]

- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.3.0
rev: v1.5.5
hooks:
- id: insert-license
name: "License inserter"
Expand Down
120 changes: 61 additions & 59 deletions examples/load_checkpoints.ipynb

Large diffs are not rendered by default.

15 changes: 12 additions & 3 deletions examples/training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"\n",
"# Based on https://stackoverflow.com/questions/67504079/how-to-check-if-an-nvidia-gpu-is-available-on-my-system\n",
"try:\n",
" subprocess.check_output('nvidia-smi')\n",
" subprocess.check_output(\"nvidia-smi\")\n",
" print(\"a GPU is connected.\")\n",
"except Exception:\n",
" # TPU or CPU\n",
Expand Down Expand Up @@ -82,6 +82,7 @@
"outputs": [],
"source": [
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"from jumanji.training.train import train\n",
Expand Down Expand Up @@ -117,7 +118,7 @@
},
"outputs": [],
"source": [
"#@title Download Jumanji Configs (run me) { display-mode: \"form\" }\n",
"# @title Download Jumanji Configs (run me) { display-mode: \"form\" }\n",
"\n",
"import os\n",
"import requests\n",
Expand Down Expand Up @@ -407,7 +408,15 @@
],
"source": [
"with initialize(version_base=None, config_path=\"configs\"):\n",
" cfg = compose(config_name=\"config.yaml\", overrides=[f\"env={env}\", f\"agent={agent}\", \"logger.type=terminal\", \"logger.save_checkpoint=true\"])\n",
" cfg = compose(\n",
" config_name=\"config.yaml\",\n",
" overrides=[\n",
" f\"env={env}\",\n",
" f\"agent={agent}\",\n",
" \"logger.type=terminal\",\n",
" \"logger.save_checkpoint=true\",\n",
" ],\n",
" )\n",
"\n",
"train(cfg)"
]
Expand Down
8 changes: 2 additions & 6 deletions jumanji/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,8 @@
register(id="PacMan-v1", entry_point="jumanji.environments:PacMan")

# SlidingTilePuzzle - A sliding tile puzzle environment with the default grid size of 5x5.
register(
id="SlidingTilePuzzle-v0", entry_point="jumanji.environments:SlidingTilePuzzle"
)
register(id="SlidingTilePuzzle-v0", entry_point="jumanji.environments:SlidingTilePuzzle")

# LevelBasedForaging with a random generator with 8 grid size,
# 2 agents and 2 food items and the maximum agent's level is 2.
register(
id="LevelBasedForaging-v0", entry_point="jumanji.environments:LevelBasedForaging"
)
register(id="LevelBasedForaging-v0", entry_point="jumanji.environments:LevelBasedForaging")
16 changes: 6 additions & 10 deletions jumanji/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def __repr__(self) -> str:

def __init__(self) -> None:
"""Initialize environment."""
self.observation_spec
self.action_spec
self.reward_spec
self.discount_spec
self.observation_spec # noqa: B018
self.action_spec # noqa: B018
self.reward_spec # noqa: B018
self.discount_spec # noqa: B018

@abc.abstractmethod
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
Expand All @@ -67,9 +67,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
"""

@abc.abstractmethod
def step(
self, state: State, action: chex.Array
) -> Tuple[State, TimeStep[Observation]]:
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
"""Run one timestep of the environment's dynamics.
Args:
Expand Down Expand Up @@ -115,9 +113,7 @@ def discount_spec(self) -> specs.BoundedArray:
Returns:
discount_spec: a `specs.BoundedArray` spec.
"""
return specs.BoundedArray(
shape=(), dtype=float, minimum=0.0, maximum=1.0, name="discount"
)
return specs.BoundedArray(shape=(), dtype=float, minimum=0.0, maximum=1.0, name="discount")

@property
def unwrapped(self) -> Environment[State, ActionSpec, Observation]:
Expand Down
9 changes: 3 additions & 6 deletions jumanji/environments/commons/maze_utils/maze_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
nodes) through a vertical wall must be at an even y coordinate while a passage through a horizontal
wall must be at an even x coordinate.
"""

from typing import NamedTuple, Tuple

import chex
Expand Down Expand Up @@ -123,9 +124,7 @@ def create_chamber(chambers: Stack, x: int, y: int, width: int, height: int) ->
return new_stack


def split_vertically(
state: MazeGenerationState, chamber: chex.Array
) -> MazeGenerationState:
def split_vertically(state: MazeGenerationState, chamber: chex.Array) -> MazeGenerationState:
"""Split the chamber vertically.
Randomly draw a horizontal wall to split the chamber vertically. Randomly open a passage
Expand Down Expand Up @@ -215,8 +214,6 @@ def generate_maze(width: int, height: int, key: chex.PRNGKey) -> chex.Array:

initial_state = MazeGenerationState(maze, chambers, key)

final_state = jax.lax.while_loop(
chambers_remaining, split_next_chamber, initial_state
)
final_state = jax.lax.while_loop(chambers_remaining, split_next_chamber, initial_state)

return final_state.maze
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,7 @@ def test_random_odd(self, key: chex.PRNGKey) -> None:
assert i % 2 == 1
assert 0 <= i < max_val

def test_split_vertically(
self, maze: chex.Array, chambers: Stack, key: chex.PRNGKey
) -> None:
def test_split_vertically(self, maze: chex.Array, chambers: Stack, key: chex.PRNGKey) -> None:
"""Test that a horizontal wall is drawn and that subchambers are added to stack."""
chambers, chamber = stack_pop(chambers)
state = MazeGenerationState(maze, chambers, key)
Expand All @@ -124,9 +122,7 @@ def test_split_vertically(

assert chambers.insertion_index >= 1

def test_split_horizontally(
self, maze: chex.Array, chambers: Stack, key: chex.PRNGKey
) -> None:
def test_split_horizontally(self, maze: chex.Array, chambers: Stack, key: chex.PRNGKey) -> None:
"""Test that a vertical wall is drawn and that subchambers are added to stack."""
chambers, chamber = stack_pop(chambers)
state = MazeGenerationState(maze, chambers, key)
Expand Down
4 changes: 2 additions & 2 deletions jumanji/environments/commons/maze_utils/maze_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Sequence, Tuple
from typing import Callable, ClassVar, Dict, List, Optional, Sequence, Tuple

import chex
import matplotlib.animation
Expand All @@ -32,7 +32,7 @@ class MazeViewer(Viewer):
FONT_STYLE = "monospace"
FIGURE_SIZE = (10.0, 10.0)
# EMPTY is white, WALL is black
COLORS = {EMPTY: [1, 1, 1], WALL: [0, 0, 0]}
COLORS: ClassVar[Dict[int, List[int]]] = {EMPTY: [1, 1, 1], WALL: [0, 0, 0]}

def __init__(self, name: str, render_mode: str = "human") -> None:
"""Viewer for a maze environment.
Expand Down
1 change: 1 addition & 0 deletions jumanji/environments/commons/maze_utils/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
[. . . .]]
"""

from typing import NamedTuple, Tuple

import chex
Expand Down
16 changes: 4 additions & 12 deletions jumanji/environments/logic/game_2048/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ class Game2048(Environment[State, specs.DiscreteArray, Observation]):
```
"""

def __init__(
self, board_size: int = 4, viewer: Optional[Viewer[State]] = None
) -> None:
def __init__(self, board_size: int = 4, viewer: Optional[Viewer[State]] = None) -> None:
"""Initialize the 2048 game.
Args:
Expand Down Expand Up @@ -166,9 +164,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:

return state, timestep

def step(
self, state: State, action: chex.Array
) -> Tuple[State, TimeStep[Observation]]:
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
"""Updates the environment state after the agent takes an action.
Args:
Expand Down Expand Up @@ -279,9 +275,7 @@ def _add_random_cell(self, board: Board, key: chex.PRNGKey) -> Board:
position = jnp.divmod(tile_idx, self.board_size)

# Choose the value of the new cell: 1 with probability 90% or 2 with probability of 10%
cell_value = jax.random.choice(
subkey, jnp.array([1, 2]), p=jnp.array([0.9, 0.1])
)
cell_value = jax.random.choice(subkey, jnp.array([1, 2]), p=jnp.array([0.9, 0.1]))
board = board.at[position].set(cell_value)

return board
Expand Down Expand Up @@ -325,9 +319,7 @@ def animate(
Returns:
animation.FuncAnimation: the animation object that was created.
"""
return self._viewer.animate(
states=states, interval=interval, save_path=save_path
)
return self._viewer.animate(states=states, interval=interval, save_path=save_path)

def close(self) -> None:
"""Perform any necessary cleanup.
Expand Down
12 changes: 3 additions & 9 deletions jumanji/environments/logic/game_2048/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ def can_move_left_row_cond(carry: CanMoveCarry) -> chex.Numeric:
def can_move_left_row_body(carry: CanMoveCarry) -> CanMoveCarry:
"""Check if the current tiles can move and increment the indices."""
# Check if tiles can move
can_move = (carry.origin != 0) & (
(carry.target == 0) | (carry.target == carry.origin)
)
can_move = (carry.origin != 0) & ((carry.target == 0) | (carry.target == carry.origin))

# Increment indices as if performed a no op
# If not performing no op, loop will be terminated anyways
Expand All @@ -75,17 +73,13 @@ def can_move_left_row_body(carry: CanMoveCarry) -> CanMoveCarry:
)

# Return updated carry
return carry._replace(
can_move=can_move, target_idx=target_idx, origin_idx=origin_idx
)
return carry._replace(can_move=can_move, target_idx=target_idx, origin_idx=origin_idx)


def can_move_left_row(row: chex.Array) -> bool:
"""Check if row can move left."""
carry = CanMoveCarry(can_move=False, row=row, target_idx=0, origin_idx=1)
can_move: bool = jax.lax.while_loop(
can_move_left_row_cond, can_move_left_row_body, carry
)[0]
can_move: bool = jax.lax.while_loop(can_move_left_row_cond, can_move_left_row_body, carry)[0]
return can_move


Expand Down
12 changes: 4 additions & 8 deletions jumanji/environments/logic/game_2048/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Tuple
from typing import ClassVar, Dict, Optional, Sequence, Tuple

import jax.numpy as jnp
import matplotlib.animation
Expand All @@ -24,7 +24,7 @@


class Game2048Viewer(Viewer):
COLORS = {
COLORS: ClassVar[Dict[int | str, str]] = {
1: "#ccc0b3",
2: "#eee4da",
4: "#ede0c8",
Expand Down Expand Up @@ -158,13 +158,9 @@ def render_tile(self, tile_value: int, ax: plt.Axes, row: int, col: int) -> None
"""
# Set the background color of the tile based on its value.
if tile_value <= 16384:
rect = plt.Rectangle(
[col - 0.5, row - 0.5], 1, 1, color=self.COLORS[int(tile_value)]
)
rect = plt.Rectangle([col - 0.5, row - 0.5], 1, 1, color=self.COLORS[int(tile_value)])
else:
rect = plt.Rectangle(
[col - 0.5, row - 0.5], 1, 1, color=self.COLORS["other"]
)
rect = plt.Rectangle([col - 0.5, row - 0.5], 1, 1, color=self.COLORS["other"])
ax.add_patch(rect)

if tile_value in [2, 4]:
Expand Down
Loading

0 comments on commit 0acf557

Please sign in to comment.