
Commit 8b286ce

Merge branch 'dev' into 7499-torchio-transforms-wrapper

2 parents: 2a7842d + 20372f0

21 files changed: +1012 / -117 lines

docs/source/networks.rst

Lines changed: 5 additions & 0 deletions

@@ -630,6 +630,11 @@ Nets
 .. autoclass:: ViTAutoEnc
     :members:
 
+`MaskedAutoEncoderViT`
+~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: MaskedAutoEncoderViT
+    :members:
+
 `FullyConnectedNet`
 ~~~~~~~~~~~~~~~~~~~
 .. autoclass:: FullyConnectedNet

monai/bundle/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -43,4 +43,4 @@
     MACRO_KEY,
     load_bundle_config,
 )
-from .workflows import BundleWorkflow, ConfigWorkflow
+from .workflows import BundleWorkflow, ConfigWorkflow, PythonicWorkflow

monai/bundle/reference_resolver.py

Lines changed: 10 additions & 0 deletions

@@ -192,6 +192,16 @@ def get_resolved_content(self, id: str, **kwargs: Any) -> ConfigExpression | str
         """
         return self._resolve_one_item(id=id, **kwargs)
 
+    def remove_resolved_content(self, id: str) -> Any | None:
+        """
+        Remove the resolved ``ConfigItem`` by id.
+
+        Args:
+            id: id name of the expected item.
+
+        """
+        return self.resolved_content.pop(id) if id in self.resolved_content else None
+
     @classmethod
     def normalize_id(cls, id: str | int) -> str:
         """

monai/bundle/workflows.py

Lines changed: 187 additions & 22 deletions

@@ -44,12 +44,18 @@ class BundleWorkflow(ABC):
         workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
             or "infer", "inference", "eval", "evaluation" for a inference workflow,
             other unsupported string will raise a ValueError.
-            default to `train` for train workflow.
+            default to `None` for only using meta properties.
         workflow: specifies the workflow type: "train" or "training" for a training workflow,
             or "infer", "inference", "eval", "evaluation" for a inference workflow,
             other unsupported string will raise a ValueError.
             default to `None` for common workflow.
-        properties_path: the path to the JSON file of properties.
+        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
+            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
+            properties will default to loading from "meta". If `properties_path` is None, default properties
+            will be sourced from "monai/bundle/properties.py" based on the workflow_type:
+            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
+            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
+            For workflow_type = None : only `MetaProperties` will be loaded.
         meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.
         logging_file: config file for `logging` module in the program. for more details:
             https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
@@ -97,29 +103,50 @@ def __init__(
             meta_file = None
 
         workflow_type = workflow if workflow is not None else workflow_type
-        if workflow_type is None and properties_path is None:
-            self.properties = copy(MetaProperties)
-            self.workflow_type = None
-            self.meta_file = meta_file
-            return
+        if workflow_type is not None:
+            if workflow_type.lower() in self.supported_train_type:
+                workflow_type = "train"
+            elif workflow_type.lower() in self.supported_infer_type:
+                workflow_type = "infer"
+            else:
+                raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")
+
         if properties_path is not None:
             properties_path = Path(properties_path)
             if not properties_path.is_file():
                 raise ValueError(f"Property file {properties_path} does not exist.")
             with open(properties_path) as json_file:
-                self.properties = json.load(json_file)
-            self.workflow_type = None
-            self.meta_file = meta_file
-            return
-        if workflow_type.lower() in self.supported_train_type:  # type: ignore[union-attr]
-            self.properties = {**TrainProperties, **MetaProperties}
-            self.workflow_type = "train"
-        elif workflow_type.lower() in self.supported_infer_type:  # type: ignore[union-attr]
-            self.properties = {**InferProperties, **MetaProperties}
-            self.workflow_type = "infer"
+                try:
+                    properties = json.load(json_file)
+                    self.properties: dict = {}
+                    if workflow_type is not None and workflow_type in properties:
+                        self.properties = properties[workflow_type]
+                        if "meta" in properties:
+                            self.properties.update(properties["meta"])
+                    elif workflow_type is None:
+                        if "meta" in properties:
+                            self.properties = properties["meta"]
+                            logger.info(
+                                "No workflow type specified, default to load meta properties from property file."
+                            )
+                        else:
+                            logger.warning("No 'meta' key found in properties while workflow_type is None.")
+                except KeyError as e:
+                    raise ValueError(f"{workflow_type} not found in property file {properties_path}") from e
+                except json.JSONDecodeError as e:
+                    raise ValueError(f"Error decoding JSON from property file {properties_path}") from e
         else:
-            raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")
+            if workflow_type == "train":
+                self.properties = {**TrainProperties, **MetaProperties}
+            elif workflow_type == "infer":
+                self.properties = {**InferProperties, **MetaProperties}
+            elif workflow_type is None:
+                self.properties = copy(MetaProperties)
+                logger.info("No workflow type and property file specified, default to 'meta' properties.")
+            else:
+                raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")
 
+        self.workflow_type = workflow_type
         self.meta_file = meta_file
 
     @abstractmethod
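To make the new loading branches concrete, here is a hedged sketch of a `properties_path` JSON file keyed by workflow type plus a shared "meta" section; the property names and ids below are illustrative assumptions, not taken from `monai/bundle/properties.py`:

# Illustrative properties file; "description"/"required"/"id" follow the
# BundleProperty / BundlePropertyConfig key names, but the entries are made up.
import json

example_properties = {
    "train": {"trainer": {"description": "training engine", "required": True, "id": "train::trainer"}},
    "infer": {"evaluator": {"description": "inference engine", "required": True, "id": "evaluator"}},
    "meta": {"version": {"description": "bundle version", "required": True, "id": "_meta_::version"}},
}
with open("properties.json", "w") as f:
    json.dump(example_properties, f, indent=4)

# A workflow constructed with workflow_type="train" and this properties_path would merge
# example_properties["train"] with example_properties["meta"]; with workflow_type=None,
# only the "meta" section is loaded.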
@@ -226,6 +253,124 @@ def check_properties(self) -> list[str] | None:
         return [n for n, p in self.properties.items() if p.get(BundleProperty.REQUIRED, False) and not hasattr(self, n)]
 
 
+class PythonicWorkflow(BundleWorkflow):
+    """
+    Base class for the pythonic workflow specification in bundle, it can be a training, evaluation or inference workflow.
+    It defines the basic interfaces for the bundle workflow behavior: `initialize`, `finalize`, etc.
+    This also provides the interface to get / set public properties to interact with a bundle workflow through
+    defined `get_<property>` accessor methods or directly defining members of the object.
+    For how to set the properties, users can define the `_set_<property>` methods or directly set the members of the object.
+    The `initialize` method is called to set up the workflow before running. This method sets up internal state
+    and prepares properties. If properties are modified after the workflow has been initialized, `self._is_initialized`
+    is set to `False`. Before running the workflow again, `initialize` should be called to ensure that the workflow is
+    properly set up with the new property values.
+
+    Args:
+        workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
+            or "infer", "inference", "eval", "evaluation" for a inference workflow,
+            other unsupported string will raise a ValueError.
+            default to `None` for only using meta properties.
+        workflow: specifies the workflow type: "train" or "training" for a training workflow,
+            or "infer", "inference", "eval", "evaluation" for a inference workflow,
+            other unsupported string will raise a ValueError.
+            default to `None` for common workflow.
+        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
+            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
+            properties will default to loading from "meta". If `properties_path` is None, default properties
+            will be sourced from "monai/bundle/properties.py" based on the workflow_type:
+            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
+            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
+            For workflow_type = None : only `MetaProperties` will be loaded.
+        config_file: path to the config file, typically used to store hyperparameters.
+        meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.
+        logging_file: config file for `logging` module in the program. for more details:
+            https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
+
+    """
+
+    supported_train_type: tuple = ("train", "training")
+    supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation")
+
+    def __init__(
+        self,
+        workflow_type: str | None = None,
+        properties_path: PathLike | None = None,
+        config_file: str | Sequence[str] | None = None,
+        meta_file: str | Sequence[str] | None = None,
+        logging_file: str | None = None,
+        **override: Any,
+    ):
+        meta_file = str(Path(os.getcwd()) / "metadata.json") if meta_file is None else meta_file
+        super().__init__(
+            workflow_type=workflow_type, properties_path=properties_path, meta_file=meta_file, logging_file=logging_file
+        )
+        self._props_vals: dict = {}
+        self._set_props_vals: dict = {}
+        self.parser = ConfigParser()
+        if config_file is not None:
+            self.parser.read_config(f=config_file)
+        if self.meta_file is not None:
+            self.parser.read_meta(f=self.meta_file)
+
+        # the rest key-values in the _args are to override config content
+        self.parser.update(pairs=override)
+        self._is_initialized: bool = False
+
+    def initialize(self, *args: Any, **kwargs: Any) -> Any:
+        """
+        Initialize the bundle workflow before running.
+        """
+        self._props_vals = {}
+        self._is_initialized = True
+
+    def _get_property(self, name: str, property: dict) -> Any:
+        """
+        With specified property name and information, get the expected property value.
+        If the property is already generated, return from the bucket directly.
+        If user explicitly set the property, return it directly.
+        Otherwise, generate the expected property as a class private property with prefix "_".
+
+        Args:
+            name: the name of target property.
+            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
+        """
+        if not self._is_initialized:
+            raise RuntimeError("Please execute 'initialize' before getting any properties.")
+        value = None
+        if name in self._set_props_vals:
+            value = self._set_props_vals[name]
+        elif name in self._props_vals:
+            value = self._props_vals[name]
+        elif name in self.parser.config[self.parser.meta_key]:  # type: ignore[index]
+            id = self.properties.get(name, None).get(BundlePropertyConfig.ID, None)
+            value = self.parser[id]
+        else:
+            try:
+                value = getattr(self, f"get_{name}")()
+            except AttributeError as e:
+                if property[BundleProperty.REQUIRED]:
+                    raise ValueError(
+                        f"unsupported property '{name}' is required in the bundle properties,"
+                        f"need to implement a method 'get_{name}' to provide the property."
+                    ) from e
+        self._props_vals[name] = value
+        return value
+
+    def _set_property(self, name: str, property: dict, value: Any) -> Any:
+        """
+        With specified property name and information, set value for the expected property.
+        Stores user-reset initialized objects that should not be re-initialized and marks the workflow as not initialized.
+
+        Args:
+            name: the name of target property.
+            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
+            value: value to set for the property.
+
+        """
+        self._set_props_vals[name] = value
+        self._is_initialized = False
+
+
 class ConfigWorkflow(BundleWorkflow):
     """
     Specification for the config-based bundle workflow.
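A hedged sketch of how `PythonicWorkflow` is meant to be subclassed: required properties are provided via `get_<property>` methods and resolved lazily through `_get_property`. The class name, the toy network and the `run` body below are illustrative assumptions:

# Illustrative subclass; the property it supplies ("network_def") comes from the default
# InferProperties, everything else here is made up for the sketch.
import torch
from monai.bundle import PythonicWorkflow


class ToyInference(PythonicWorkflow):
    def get_network_def(self):
        # called lazily by _get_property the first time self.network_def is accessed
        return torch.nn.Identity()

    def run(self):
        net = self.network_def  # resolved once, then cached in self._props_vals
        return net(torch.zeros(1, 1, 8, 8))

    def finalize(self):
        pass


wf = ToyInference(workflow_type="infer")
wf.initialize()                    # required before any property access
out = wf.run()
wf.network_def = torch.nn.ReLU()   # routed to _set_property: stored and initialization flag cleared
wf.initialize()                    # call again before the next run()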
@@ -262,7 +407,13 @@ class ConfigWorkflow(BundleWorkflow):
             or "infer", "inference", "eval", "evaluation" for a inference workflow,
             other unsupported string will raise a ValueError.
             default to `None` for common workflow.
-        properties_path: the path to the JSON file of properties.
+        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
+            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
+            properties will default to loading from "train". If `properties_path` is None, default properties
+            will be sourced from "monai/bundle/properties.py" based on the workflow_type:
+            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
+            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
+            For workflow_type = None : only `MetaProperties` will be loaded.
         override: id-value pairs to override or add the corresponding config content.
             e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``
@@ -324,7 +475,6 @@ def __init__(
             self.parser.read_config(f=config_file)
         if self.meta_file is not None:
             self.parser.read_meta(f=self.meta_file)
-
         # the rest key-values in the _args are to override config content
         self.parser.update(pairs=override)
         self.init_id = init_id
@@ -394,8 +544,23 @@ def check_properties(self) -> list[str] | None:
                 ret.extend(wrong_props)
         return ret
 
-    def _run_expr(self, id: str, **kwargs: dict) -> Any:
-        return self.parser.get_parsed_content(id, **kwargs) if id in self.parser else None
+    def _run_expr(self, id: str, **kwargs: dict) -> list[Any]:
+        """
+        Evaluate the expression or expression list given by `id`. The resolved values from the evaluations are not stored,
+        allowing this to be evaluated repeatedly (eg. in streaming applications) without restarting the hosting process.
+        """
+        ret = []
+        if id in self.parser:
+            # suppose all the expressions are in a list, run and reset the expressions
+            if isinstance(self.parser[id], list):
+                for i in range(len(self.parser[id])):
+                    sub_id = f"{id}{ID_SEP_KEY}{i}"
+                    ret.append(self.parser.get_parsed_content(sub_id, **kwargs))
+                    self.parser.ref_resolver.remove_resolved_content(sub_id)
+            else:
+                ret.append(self.parser.get_parsed_content(id, **kwargs))
+                self.parser.ref_resolver.remove_resolved_content(id)
+        return ret
 
     def _get_prop_id(self, name: str, property: dict) -> Any:
         prop_id = property[BundlePropertyConfig.ID]
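For context, a hedged sketch of the config shape this targets: when the content under an id is a list of expressions, `_run_expr` now evaluates each entry through its sub-id and immediately un-caches it with `remove_resolved_content`, so the workflow can be run repeatedly in the same process. The config below is an illustrative assumption:

# Illustrative config; each "run" entry would be parsed as "run" + ID_SEP_KEY + index,
# and its resolved value is dropped again right after evaluation.
example_config = {
    "status": "bundle demo",
    "run": ["$print('starting', @status)", "$print('finished')"],
}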

monai/networks/blocks/selfattention.py

Lines changed: 18 additions & 4 deletions

@@ -11,7 +11,7 @@
 
 from __future__ import annotations
 
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -154,10 +154,12 @@ def __init__(
         )
         self.input_size = input_size
 
-    def forward(self, x):
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         """
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+            attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
+                B x (s_dim_1 * ... * s_dim_n). Defaults to None.
 
         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
@@ -176,7 +178,13 @@ def forward(self, x):
 
         if self.use_flash_attention:
             x = F.scaled_dot_product_attention(
-                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+                query=q,
+                key=k,
+                value=v,
+                attn_mask=attn_mask,
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=self.causal,
             )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
@@ -186,10 +194,16 @@ def forward(self, x):
                 att_mat = self.rel_positional_embedding(x, att_mat, q)
 
             if self.causal:
+                if attn_mask is not None:
+                    raise ValueError("Causal attention does not support attention masks.")
                 att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))
 
-            att_mat = att_mat.softmax(dim=-1)
+            if attn_mask is not None:
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
+                attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)
+                att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf"))
 
+            att_mat = att_mat.softmax(dim=-1)
             if self.save_attn:
                 # no gradients and new tensor;
                 # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
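A minimal sketch of the new `attn_mask` argument on `SABlock`; the sizes are arbitrary examples, and per the docstring above the mask is `B x L`, with zero entries excluded from attention in the non-flash path:

# Minimal sketch; sizes are arbitrary examples.
import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=128, num_heads=4)
x = torch.randn(2, 16, 128)           # B x L x C
attn_mask = torch.ones(2, 16)         # B x L
attn_mask[:, 8:] = 0                  # e.g. mask out padded positions in the second half
out = block(x, attn_mask=attn_mask)   # B x L x C; masked positions get no attention weight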

monai/networks/blocks/transformerblock.py

Lines changed: 4 additions & 2 deletions

@@ -90,8 +90,10 @@ def __init__(
             use_flash_attention=use_flash_attention,
         )
 
-    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + self.attn(self.norm1(x))
+    def forward(
+        self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        x = x + self.attn(self.norm1(x), attn_mask=attn_mask)
         if self.with_cross_attention:
             x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
         x = x + self.mlp(self.norm2(x))

monai/networks/nets/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -53,6 +53,7 @@
 from .generator import Generator
 from .highresnet import HighResBlock, HighResNet
 from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
+from .masked_autoencoder_vit import MaskedAutoEncoderViT
 from .mednext import (
     MedNeXt,
     MedNext,
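A hedged sketch of constructing the newly exported network; the constructor arguments and the forward return values shown here are assumptions based on the ViT family in MONAI and should be checked against the `MaskedAutoEncoderViT` docstring:

# Assumed arguments/outputs; verify against the actual MaskedAutoEncoderViT signature.
import torch
from monai.networks.nets import MaskedAutoEncoderViT

net = MaskedAutoEncoderViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16))
x = torch.randn(1, 1, 96, 96, 96)
prediction, mask = net(x)  # assumed: reconstructed patches plus the random mask that was applied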
