Skip to content

Commit 766e7ad

Browse files
authored
1.0.0-beta.5: Merge pull request #22 from Pabloo22/development
1.0.0-beta.5: Enhance `ResourceTaskGraphObservation` with generic environment type
2 parents dd12f83 + 9da307a commit 766e7ad

File tree

5 files changed

+24
-31
lines changed

5 files changed

+24
-31
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ See [this](https://colab.research.google.com/drive/1XV_Rvq1F2ns6DFG8uNj66q_rcoww
3636
Version 1.0.0 is currently in beta stage and can be installed with:
3737

3838
```bash
39-
pip install job-shop-lib==1.0.0b4
39+
pip install job-shop-lib==1.0.0b5
4040
```
4141

4242
Although this version is not stable and may contain breaking changes in subsequent releases, it is recommended to install it to access the new reinforcement learning environments and familiarize yourself with new changes (see the [latest pull requests](https://github.com/Pabloo22/job_shop_lib/pulls?q=is%3Apr+is%3Aclosed)). There is a [documentation page](https://job-shop-lib.readthedocs.io/en/latest/) for versions 1.0.0a3 and onward.

job_shop_lib/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from job_shop_lib._base_solver import BaseSolver, Solver
2020

2121

22-
__version__ = "1.0.0-b.4"
22+
__version__ = "1.0.0-b.5"
2323

2424
__all__ = [
2525
"Operation",

job_shop_lib/reinforcement_learning/_resource_task_graph_observation.py

+17-24
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Contains wrappers for the environments."""
22

3-
from typing import TypeVar, TypedDict
3+
from typing import TypeVar, TypedDict, Generic
44
from gymnasium import ObservationWrapper
55
import numpy as np
66
from numpy.typing import NDArray
@@ -12,11 +12,13 @@
1212
create_edge_type_dict,
1313
map_values,
1414
)
15-
from job_shop_lib.graphs import NodeType, JobShopGraph
16-
from job_shop_lib.exceptions import ValidationError
15+
from job_shop_lib.graphs import NodeType
1716
from job_shop_lib.dispatching.feature_observers import FeatureType
1817

1918
T = TypeVar("T", bound=np.number)
19+
EnvType = TypeVar( # pylint: disable=invalid-name
20+
"EnvType", bound=SingleJobShopGraphEnv | MultiJobShopGraphEnv
21+
)
2022

2123

2224
class ResourceTaskGraphObservationDict(TypedDict):
@@ -28,7 +30,7 @@ class ResourceTaskGraphObservationDict(TypedDict):
2830

2931

3032
# pylint: disable=line-too-long
31-
class ResourceTaskGraphObservation(ObservationWrapper):
33+
class ResourceTaskGraphObservation(ObservationWrapper, Generic[EnvType]):
3234
"""Observation wrapper that converts an observation following the
3335
:class:`ObservationDict` format to a format suitable to PyG's
3436
[`HeteroData`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.HeteroData.html).
@@ -48,26 +50,12 @@ class ResourceTaskGraphObservation(ObservationWrapper):
4850
env: The environment to wrap.
4951
"""
5052

51-
def __init__(self, env: SingleJobShopGraphEnv | MultiJobShopGraphEnv):
53+
def __init__(self, env: EnvType):
5254
super().__init__(env)
55+
self.env = env # Unnecessary, but makes mypy happy
5356
self.global_to_local_id = self._compute_id_mappings()
5457
self.type_ranges = self._compute_node_type_ranges()
5558

56-
@property
57-
def job_shop_graph(self) -> JobShopGraph:
58-
"""Returns the job shop graph from the environment.
59-
60-
Raises:
61-
ValidationError: If the environment is not an instance of
62-
``SingleJobShopGraphEnv`` or ``MultiJobShopGraphEnv``.
63-
"""
64-
if isinstance(self.env, (SingleJobShopGraphEnv, MultiJobShopGraphEnv)):
65-
return self.env.job_shop_graph
66-
raise ValidationError(
67-
"The environment must be an instance of "
68-
"SingleJobShopGraphEnv or MultiJobShopGraphEnv"
69-
)
70-
7159
def step(self, action: tuple[int, int]):
7260
"""Takes a step in the environment.
7361
@@ -127,7 +115,7 @@ def _compute_id_mappings(self) -> dict[int, int]:
127115
"""
128116
mappings = {}
129117
for node_type in NodeType:
130-
type_nodes = self.job_shop_graph.nodes_by_type[node_type]
118+
type_nodes = self.unwrapped.job_shop_graph.nodes_by_type[node_type]
131119
if not type_nodes:
132120
continue
133121
# Create mapping from global ID to local ID
@@ -148,7 +136,7 @@ def _compute_node_type_ranges(self) -> dict[str, tuple[int, int]]:
148136
"""
149137
type_ranges = {}
150138
for node_type in NodeType:
151-
type_nodes = self.job_shop_graph.nodes_by_type[node_type]
139+
type_nodes = self.unwrapped.job_shop_graph.nodes_by_type[node_type]
152140
if not type_nodes:
153141
continue
154142
start = min(node.node_id for node in type_nodes)
@@ -197,7 +185,7 @@ def _create_node_features_dict(
197185
}
198186
node_features_dict = {}
199187
for node_type, feature_type in node_type_to_feature_type.items():
200-
if node_type in self.job_shop_graph.nodes_by_type:
188+
if node_type in self.unwrapped.job_shop_graph.nodes_by_type:
201189
node_features_dict[feature_type.value] = observation[
202190
feature_type.value
203191
]
@@ -210,7 +198,7 @@ def _create_node_features_dict(
210198
]
211199
job_ids_of_ops = [
212200
node.operation.job_id
213-
for node in self.job_shop_graph.nodes_by_type[
201+
for node in self.unwrapped.job_shop_graph.nodes_by_type[
214202
NodeType.OPERATION
215203
]
216204
]
@@ -256,3 +244,8 @@ def _remove_nodes(
256244
)[0]
257245

258246
return removed_nodes_dict, original_ids_dict
247+
248+
@property
249+
def unwrapped(self) -> EnvType:
250+
"""Returns the unwrapped environment."""
251+
return self.env # type: ignore[return-value]

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "job-shop-lib"
3-
version = "1.0.0-beta.4"
3+
version = "1.0.0-beta.5"
44
description = "An easy-to-use and modular Python library for the Job Shop Scheduling Problem (JSSP)"
55
authors = ["Pabloo22 <[email protected]>"]
66
license = "MIT"

tests/reinforcement_learning/test_resource_task_graph_observation.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ def test_edge_index_dict(
1414
single_env_ft06_resource_task_graph_with_all_features
1515
)
1616
obs, info = env.reset()
17-
max_index = env.job_shop_graph.instance.num_operations
17+
max_index = env.unwrapped.job_shop_graph.instance.num_operations
1818
edge_index_dict = obs["edge_index_dict"]
1919
_check_that_edge_index_has_been_reindexed(edge_index_dict, max_index)
2020

2121
done = False
2222
_, machine_id, job_id = info["available_operations_with_ids"][0]
23-
removed_nodes = env.job_shop_graph.removed_nodes
23+
removed_nodes = env.unwrapped.job_shop_graph.removed_nodes
2424
_check_count_of_unique_ids(edge_index_dict, removed_nodes)
2525
while not done:
2626
obs, _, done, _, info = env.step((job_id, machine_id))
@@ -43,7 +43,7 @@ def test_node_features_dict(
4343
obs, info = env.reset()
4444
done = False
4545
_, machine_id, job_id = info["available_operations_with_ids"][0]
46-
removed_nodes = env.job_shop_graph.removed_nodes
46+
removed_nodes = env.unwrapped.job_shop_graph.removed_nodes
4747
_check_number_of_nodes(obs["node_features_dict"], removed_nodes)
4848
while not done:
4949
obs, _, done, _, info = env.step((job_id, machine_id))
@@ -64,7 +64,7 @@ def test_original_ids_dict(
6464
obs, info = env.reset()
6565
done = False
6666
_, machine_id, job_id = info["available_operations_with_ids"][0]
67-
removed_nodes = env.job_shop_graph.removed_nodes
67+
removed_nodes = env.unwrapped.job_shop_graph.removed_nodes
6868
_check_original_ids_dict(obs["original_ids_dict"], removed_nodes)
6969
while not done:
7070
obs, _, done, _, info = env.step((job_id, machine_id))

0 commit comments

Comments
 (0)