Skip to content

Commit dd12f83

Browse files
authored
1.0.0-beta.4: Merge pull request #21 from Pabloo22/development
1.0.0-beta.4: Add `ResourceTaskGraphObservation` wrapper
2 parents 4aceec6 + 04ffedd commit dd12f83

12 files changed

+564
-18
lines changed

Diff for: README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ See [this](https://colab.research.google.com/drive/1XV_Rvq1F2ns6DFG8uNj66q_rcoww
3636
Version 1.0.0 is currently in beta stage and can be installed with:
3737

3838
```bash
39-
pip install job-shop-lib==1.0.0b3
39+
pip install job-shop-lib==1.0.0b4
4040
```
4141

4242
Although this version is not stable and may contain breaking changes in subsequent releases, it is recommended to install it to access the new reinforcement learning environments and familiarize yourself with new changes (see the [latest pull requests](https://github.com/Pabloo22/job_shop_lib/pulls?q=is%3Apr+is%3Aclosed)). There is a [documentation page](https://job-shop-lib.readthedocs.io/en/latest/) for versions 1.0.0a3 and onward.

Diff for: job_shop_lib/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from job_shop_lib._base_solver import BaseSolver, Solver
2020

2121

22-
__version__ = "1.0.0-b.3"
22+
__version__ = "1.0.0-b.4"
2323

2424
__all__ = [
2525
"Operation",

Diff for: job_shop_lib/reinforcement_learning/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
RenderConfig
1414
add_padding
1515
create_edge_type_dict
16+
ResourceTaskGraphObservation
17+
ResourceTaskGraphObservationDict
1618
1719
"""
1820

@@ -31,6 +33,7 @@
3133
from job_shop_lib.reinforcement_learning._utils import (
3234
add_padding,
3335
create_edge_type_dict,
36+
map_values,
3437
)
3538

3639
from job_shop_lib.reinforcement_learning._single_job_shop_graph_env import (
@@ -39,6 +42,9 @@
3942
from job_shop_lib.reinforcement_learning._multi_job_shop_graph_env import (
4043
MultiJobShopGraphEnv,
4144
)
45+
from ._resource_task_graph_observation import (
46+
ResourceTaskGraphObservation, ResourceTaskGraphObservationDict
47+
)
4248

4349

4450
__all__ = [
@@ -52,4 +58,7 @@
5258
"add_padding",
5359
"MultiJobShopGraphEnv",
5460
"create_edge_type_dict",
61+
"ResourceTaskGraphObservation",
62+
"map_values",
63+
"ResourceTaskGraphObservationDict",
5564
]

Diff for: job_shop_lib/reinforcement_learning/_multi_job_shop_graph_env.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -235,13 +235,13 @@ def ready_operations_filter(
235235
@ready_operations_filter.setter
236236
def ready_operations_filter(
237237
self,
238-
pruning_function: Callable[
238+
ready_operations_filter: Callable[
239239
[Dispatcher, List[Operation]], List[Operation]
240240
],
241241
) -> None:
242242
"""Sets the ready operations filter."""
243243
self.single_job_shop_graph_env.dispatcher.ready_operations_filter = (
244-
pruning_function
244+
ready_operations_filter
245245
)
246246

247247
@property
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
"""Contains wrappers for the environments."""
2+
3+
from typing import TypeVar, TypedDict
4+
from gymnasium import ObservationWrapper
5+
import numpy as np
6+
from numpy.typing import NDArray
7+
8+
from job_shop_lib.reinforcement_learning import (
9+
ObservationDict,
10+
SingleJobShopGraphEnv,
11+
MultiJobShopGraphEnv,
12+
create_edge_type_dict,
13+
map_values,
14+
)
15+
from job_shop_lib.graphs import NodeType, JobShopGraph
16+
from job_shop_lib.exceptions import ValidationError
17+
from job_shop_lib.dispatching.feature_observers import FeatureType
18+
19+
T = TypeVar("T", bound=np.number)
20+
21+
22+
class ResourceTaskGraphObservationDict(TypedDict):
23+
"""Represents a dictionary for resource task graph observations."""
24+
25+
edge_index_dict: dict[str, NDArray[np.int64]]
26+
node_features_dict: dict[str, NDArray[np.float32]]
27+
original_ids_dict: dict[str, NDArray[np.int32]]
28+
29+
30+
# pylint: disable=line-too-long
31+
class ResourceTaskGraphObservation(ObservationWrapper):
32+
"""Observation wrapper that converts an observation following the
33+
:class:`ObservationDict` format to a format suitable to PyG's
34+
[`HeteroData`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.HeteroData.html).
35+
36+
In particular, the ``edge_index`` is converted into a ``edge_index_dict``
37+
with keys ``(node_type_i, "to", node_type_j)``. The ``node_type_i`` and
38+
``node_type_j`` are the node types of the source and target nodes,
39+
respectively.
40+
41+
Attributes:
42+
global_to_local_id: A dictionary mapping global node IDs to local node
43+
IDs for each node type.
44+
type_ranges: A dictionary mapping node type names to (start, end) index
45+
ranges.
46+
47+
Args:
48+
env: The environment to wrap.
49+
"""
50+
51+
def __init__(self, env: SingleJobShopGraphEnv | MultiJobShopGraphEnv):
52+
super().__init__(env)
53+
self.global_to_local_id = self._compute_id_mappings()
54+
self.type_ranges = self._compute_node_type_ranges()
55+
56+
@property
57+
def job_shop_graph(self) -> JobShopGraph:
58+
"""Returns the job shop graph from the environment.
59+
60+
Raises:
61+
ValidationError: If the environment is not an instance of
62+
``SingleJobShopGraphEnv`` or ``MultiJobShopGraphEnv``.
63+
"""
64+
if isinstance(self.env, (SingleJobShopGraphEnv, MultiJobShopGraphEnv)):
65+
return self.env.job_shop_graph
66+
raise ValidationError(
67+
"The environment must be an instance of "
68+
"SingleJobShopGraphEnv or MultiJobShopGraphEnv"
69+
)
70+
71+
def step(self, action: tuple[int, int]):
72+
"""Takes a step in the environment.
73+
74+
Args:
75+
action:
76+
The action to take. The action is a tuple of two integers
77+
(job_id, machine_id):
78+
the job ID and the machine ID in which to schedule the
79+
operation.
80+
81+
Returns:
82+
A tuple containing the following elements:
83+
84+
- The observation of the environment.
85+
- The reward obtained.
86+
- Whether the environment is done.
87+
- Whether the episode was truncated (always False).
88+
- A dictionary with additional information. The dictionary
89+
contains the following keys: "feature_names", the names of the
90+
features in the observation; and "available_operations_with_ids",
91+
a list of available actions in the form of (operation_id,
92+
machine_id, job_id).
93+
"""
94+
observation, reward, done, truncated, info = self.env.step(action)
95+
return self.observation(observation), reward, done, truncated, info
96+
97+
def reset(self, *, seed: int | None = None, options: dict | None = None):
98+
"""Resets the environment.
99+
100+
Args:
101+
seed:
102+
Added to match the signature of the parent class. It is not
103+
used in this method.
104+
options:
105+
Additional options to pass to the environment. Not used in
106+
this method.
107+
108+
Returns:
109+
A tuple containing the following elements:
110+
111+
- The observation of the environment.
112+
- A dictionary with additional information, keys
113+
include: "feature_names", the names of the features in the
114+
observation; and "available_operations_with_ids", a list of
115+
available a list of available actions in the form of
116+
(operation_id, machine_id, job_id).
117+
"""
118+
observation, info = self.env.reset()
119+
return self.observation(observation), info
120+
121+
def _compute_id_mappings(self) -> dict[int, int]:
122+
"""Computes mappings from global node IDs to type-local IDs.
123+
124+
Returns:
125+
A dictionary mapping global node IDs to local node IDs for each
126+
node type.
127+
"""
128+
mappings = {}
129+
for node_type in NodeType:
130+
type_nodes = self.job_shop_graph.nodes_by_type[node_type]
131+
if not type_nodes:
132+
continue
133+
# Create mapping from global ID to local ID
134+
# (0 to len(type_nodes)-1)
135+
type_mapping = {
136+
node.node_id: local_id
137+
for local_id, node in enumerate(type_nodes)
138+
}
139+
mappings.update(type_mapping)
140+
141+
return mappings
142+
143+
def _compute_node_type_ranges(self) -> dict[str, tuple[int, int]]:
144+
"""Computes index ranges for each node type.
145+
146+
Returns:
147+
Dictionary mapping node type names to (start, end) index ranges
148+
"""
149+
type_ranges = {}
150+
for node_type in NodeType:
151+
type_nodes = self.job_shop_graph.nodes_by_type[node_type]
152+
if not type_nodes:
153+
continue
154+
start = min(node.node_id for node in type_nodes)
155+
end = max(node.node_id for node in type_nodes) + 1
156+
type_ranges[node_type.name.lower()] = (start, end)
157+
158+
return type_ranges
159+
160+
def observation(self, observation: ObservationDict):
161+
edge_index_dict = create_edge_type_dict(
162+
observation["edge_index"],
163+
type_ranges=self.type_ranges,
164+
relationship="to",
165+
)
166+
# mapping from global node ID to local node ID
167+
for key, edge_index in edge_index_dict.items():
168+
edge_index_dict[key] = map_values(
169+
edge_index, self.global_to_local_id
170+
)
171+
node_features_dict = self._create_node_features_dict(observation)
172+
node_features_dict, original_ids_dict = self._remove_nodes(
173+
node_features_dict, observation["removed_nodes"]
174+
)
175+
176+
return {
177+
"edge_index_dict": edge_index_dict,
178+
"node_features_dict": node_features_dict,
179+
"original_ids_dict": original_ids_dict,
180+
}
181+
182+
def _create_node_features_dict(
183+
self, observation: ObservationDict
184+
) -> dict[str, NDArray]:
185+
"""Creates a dictionary of node features for each node type.
186+
187+
Args:
188+
observation: The observation dictionary.
189+
190+
Returns:
191+
Dictionary mapping node type names to node features.
192+
"""
193+
node_type_to_feature_type = {
194+
NodeType.OPERATION: FeatureType.OPERATIONS,
195+
NodeType.MACHINE: FeatureType.MACHINES,
196+
NodeType.JOB: FeatureType.JOBS,
197+
}
198+
node_features_dict = {}
199+
for node_type, feature_type in node_type_to_feature_type.items():
200+
if node_type in self.job_shop_graph.nodes_by_type:
201+
node_features_dict[feature_type.value] = observation[
202+
feature_type.value
203+
]
204+
continue
205+
if feature_type != FeatureType.JOBS:
206+
continue
207+
assert FeatureType.OPERATIONS.value in observation
208+
job_features = observation[
209+
feature_type.value # type: ignore[literal-required]
210+
]
211+
job_ids_of_ops = [
212+
node.operation.job_id
213+
for node in self.job_shop_graph.nodes_by_type[
214+
NodeType.OPERATION
215+
]
216+
]
217+
job_features_expanded = job_features[job_ids_of_ops]
218+
operation_features = observation[FeatureType.OPERATIONS.value]
219+
node_features_dict[FeatureType.OPERATIONS.value] = np.concatenate(
220+
(operation_features, job_features_expanded), axis=1
221+
)
222+
return node_features_dict
223+
224+
def _remove_nodes(
225+
self,
226+
node_features_dict: dict[str, NDArray[np.float32]],
227+
removed_nodes: NDArray[np.bool_],
228+
) -> tuple[dict[str, NDArray[np.float32]], dict[str, NDArray[np.int32]]]:
229+
"""Removes nodes from the node features dictionary.
230+
231+
Args:
232+
node_features_dict: The node features dictionary.
233+
234+
Returns:
235+
The node features dictionary with the nodes removed and a
236+
dictionary containing the original node ids.
237+
"""
238+
removed_nodes_dict: dict[str, NDArray[np.float32]] = {}
239+
original_ids_dict: dict[str, NDArray[np.int32]] = {}
240+
feature_type_to_node_type = {
241+
FeatureType.OPERATIONS.value: NodeType.OPERATION,
242+
FeatureType.MACHINES.value: NodeType.MACHINE,
243+
FeatureType.JOBS.value: NodeType.JOB,
244+
}
245+
for feature_type, features in node_features_dict.items():
246+
node_type = feature_type_to_node_type[feature_type].name.lower()
247+
if node_type not in self.type_ranges:
248+
continue
249+
start, end = self.type_ranges[node_type]
250+
removed_nodes_of_this_type = removed_nodes[start:end]
251+
removed_nodes_dict[node_type] = features[
252+
~removed_nodes_of_this_type
253+
]
254+
original_ids_dict[node_type] = np.where(
255+
~removed_nodes_of_this_type
256+
)[0]
257+
258+
return removed_nodes_dict, original_ids_dict

Diff for: job_shop_lib/reinforcement_learning/_single_job_shop_graph_env.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,27 @@ def reset(
243243
*,
244244
seed: Optional[int] = None,
245245
options: Optional[Dict[str, Any]] = None,
246-
) -> Tuple[ObservationDict, dict]:
247-
"""Resets the environment."""
246+
) -> Tuple[ObservationDict, dict[str, Any]]:
247+
"""Resets the environment.
248+
249+
Args:
250+
seed:
251+
Added to match the signature of the parent class. It is not
252+
used in this method.
253+
options:
254+
Additional options to pass to the environment. Not used in
255+
this method.
256+
257+
Returns:
258+
A tuple containing the following elements:
259+
260+
- The observation of the environment.
261+
- A dictionary with additional information, keys
262+
include: "feature_names", the names of the features in the
263+
observation; and "available_operations_with_ids", a list of
264+
available a list of available actions in the form of
265+
(operation_id, machine_id, job_id).
266+
"""
248267
super().reset(seed=seed, options=options)
249268
self.dispatcher.reset()
250269
obs = self.get_observation()

0 commit comments

Comments
 (0)