Skip to content

Commit ec0e3ec

Browse files
authored
Merge branch 'dev' into propose-fix-perceptual-loss-sqrt-nan
2 parents d78dc56 + 0d19a72 commit ec0e3ec

File tree

20 files changed

+192
-43
lines changed

20 files changed

+192
-43
lines changed

.github/workflows/pythonapp-min.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ jobs:
124124
strategy:
125125
fail-fast: false
126126
matrix:
127-
pytorch-version: ['2.3.1', '2.4.1', '2.5.1', 'latest']
127+
pytorch-version: ['2.4.1', '2.5.1', '2.6.0'] # FIXME: add 'latest' back once PyTorch 2.7 issues are resolved
128128
timeout-minutes: 40
129129
steps:
130130
- uses: actions/checkout@v4

.github/workflows/pythonapp.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ jobs:
155155
# install the latest pytorch for testing
156156
# however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated
157157
# fresh torch installation according to pyproject.toml
158-
python -m pip install torch>=2.3.0 torchvision
158+
python -m pip install torch>=2.4.1 torchvision
159159
- name: Check packages
160160
run: |
161161
pip uninstall monai

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-f https://download.pytorch.org/whl/cpu/torch-2.3.0%2Bcpu-cp39-cp39-linux_x86_64.whl
2-
torch>=2.3.0
2+
torch>=2.4.1, <2.7.0
33
pytorch-ignite==0.4.11
44
numpy>=1.20
55
itk>=5.2

monai/bundle/scripts.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -312,21 +312,6 @@ def _get_ngc_token(api_key, retry=0):
312312
return token
313313

314314

315-
def _get_latest_bundle_version_monaihosting(name):
316-
full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}"
317-
if has_requests:
318-
resp = requests.get(full_url)
319-
try:
320-
resp.raise_for_status()
321-
model_info = json.loads(resp.text)
322-
return model_info["model"]["latestVersionIdStr"]
323-
except requests.exceptions.HTTPError:
324-
# for monaihosting bundles, if cannot find the version, get from model zoo model_info.json
325-
return get_bundle_versions(name)["latest_version"]
326-
327-
raise ValueError("NGC API requires requests package. Please install it.")
328-
329-
330315
def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
331316
"""Examine if the package version is compatible with the MONAI version in the metadata."""
332317
version_dict = get_versions()
@@ -430,7 +415,7 @@ def _get_latest_bundle_version(
430415
name = _add_ngc_prefix(name)
431416
return _get_latest_bundle_version_ngc(name)
432417
elif source == "monaihosting":
433-
return _get_latest_bundle_version_monaihosting(name)
418+
return get_bundle_versions(name, repo="Project-MONAI/model-zoo", tag="dev")["latest_version"]
434419
elif source == "ngc_private":
435420
headers = kwargs.pop("headers", {})
436421
name = _add_ngc_prefix(name)

monai/networks/blocks/selfattention.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,11 @@ def __init__(
101101

102102
self.num_heads = num_heads
103103
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
104-
self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
104+
self.out_proj: Union[nn.Linear, nn.Identity]
105+
if include_fc:
106+
self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
107+
else:
108+
self.out_proj = nn.Identity()
105109

106110
self.qkv: Union[nn.Linear, nn.Identity]
107111
self.to_q: Union[nn.Linear, nn.Identity]

monai/networks/nets/diffusion_model_unet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,9 +1847,9 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
18471847
new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
18481848

18491849
# projection
1850-
new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight")
1851-
new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias")
1852-
1850+
if f"{block}.attn.out_proj.weight" in new_state_dict and f"{block}.attn.out_proj.bias" in new_state_dict:
1851+
new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight")
1852+
new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias")
18531853
# fix the cross attention blocks
18541854
cross_attention_blocks = [
18551855
k.replace(".out_proj.weight", "")

monai/networks/schedulers/ddpm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def step(
238238
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
239239

240240
# 6. Add noise
241-
variance = 0
241+
variance: torch.Tensor = torch.tensor(0)
242242
if timestep > 0:
243243
noise = torch.randn(
244244
model_output.size(),

monai/transforms/inverse.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import threading
1415
import warnings
1516
from collections.abc import Hashable, Mapping
1617
from contextlib import contextmanager
@@ -66,15 +67,41 @@ class TraceableTransform(Transform):
6667
The information in the stack of applied transforms must be compatible with the
6768
default collate, by only storing strings, numbers and arrays.
6869
69-
`tracing` could be enabled by `self.set_tracing` or setting
70+
`tracing` could be enabled by assigning to `self.tracing` or setting
7071
`MONAI_TRACE_TRANSFORM` when initializing the class.
7172
"""
7273

73-
tracing = MONAIEnvVars.trace_transform() != "0"
74+
def _init_trace_threadlocal(self):
75+
"""Create a `_tracing` instance member to store the thread-local tracing state value."""
76+
# needed since this class is meant to be a trait with no constructor
77+
if not hasattr(self, "_tracing"):
78+
self._tracing = threading.local()
79+
80+
# This is True while the above initialising _tracing is False when this is
81+
# called from a different thread than the one initialising _tracing.
82+
if not hasattr(self._tracing, "value"):
83+
self._tracing.value = MONAIEnvVars.trace_transform() != "0"
84+
85+
def __getstate__(self):
86+
"""When pickling, remove the `_tracing` member from the output, if present, since it's not picklable."""
87+
_dict = dict(getattr(self, "__dict__", {})) # this makes __dict__ always present in the unpickled object
88+
_slots = {k: getattr(self, k) for k in getattr(self, "__slots__", [])}
89+
_dict.pop("_tracing", None) # remove tracing
90+
return _dict if len(_slots) == 0 else (_dict, _slots)
91+
92+
@property
93+
def tracing(self) -> bool:
94+
"""
95+
Returns the tracing state, which is thread-local and initialised to `MONAIEnvVars.trace_transform() != "0"`.
96+
"""
97+
self._init_trace_threadlocal()
98+
return bool(self._tracing.value)
7499

75-
def set_tracing(self, tracing: bool) -> None:
76-
"""Set whether to trace transforms."""
77-
self.tracing = tracing
100+
@tracing.setter
101+
def tracing(self, val: bool):
102+
"""Sets the thread-local tracing state to `val`."""
103+
self._init_trace_threadlocal()
104+
self._tracing.value = val
78105

79106
@staticmethod
80107
def trace_key(key: Hashable = None):
@@ -291,7 +318,7 @@ def check_transforms_match(self, transform: Mapping) -> None:
291318

292319
def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False):
293320
"""
294-
Get most recent transform for the stack.
321+
Get most recent matching transform for the current class from the sequence of applied operations.
295322
296323
Args:
297324
data: dictionary of data or `MetaTensor`.
@@ -316,9 +343,14 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr
316343
all_transforms = data.get(self.trace_key(key), MetaTensor.get_default_applied_operations())
317344
else:
318345
raise ValueError(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}.")
346+
347+
if not all_transforms:
348+
raise ValueError(f"Item of type {type(data)} (key: {key}, pop: {pop}) has empty 'applied_operations'")
349+
319350
if check:
320351
self.check_transforms_match(all_transforms[-1])
321-
return all_transforms.pop() if pop else all_transforms[-1]
352+
353+
return all_transforms.pop(-1) if pop else all_transforms[-1]
322354

323355
def pop_transform(self, data, key: Hashable = None, check: bool = True):
324356
"""

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
requires = [
33
"wheel",
44
"setuptools",
5-
"torch>=2.3.0",
5+
"torch>=2.4.1, <2.7.0",
66
"ninja",
77
"packaging"
88
]
99

1010
[tool.black]
1111
line-length = 120
12-
target-version = ['py38', 'py39', 'py310']
12+
target-version = ['py39', 'py310', 'py311', 'py312']
1313
include = '\.pyi?$'
1414
exclude = '''
1515
(

requirements-dev.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ mccabe
1717
pep8-naming
1818
pycodestyle
1919
pyflakes
20-
black>=22.12
20+
black>=25.1.0
2121
isort>=5.1, <6.0
2222
ruff
2323
pytype>=2020.6.1; platform_system != "Windows"
@@ -29,6 +29,7 @@ torchvision
2929
psutil
3030
cucim-cu12; platform_system == "Linux" and python_version >= "3.9" and python_version <= "3.10"
3131
openslide-python
32+
openslide-bin
3233
imagecodecs; platform_system == "Linux" or platform_system == "Darwin"
3334
tifffile; platform_system == "Linux" or platform_system == "Darwin"
3435
pandas

0 commit comments

Comments
 (0)