From fd4145a6338c2f5dd66bf13f5757fa86d47c0a09 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Tue, 25 Nov 2025 11:34:58 +0000 Subject: [PATCH 01/33] Add content-based selection and set as fusion default --- shimmer/__init__.py | 2 + shimmer/modules/global_workspace.py | 18 ++++- shimmer/modules/selection.py | 111 +++++++++++++++++++++++++++- 3 files changed, 126 insertions(+), 5 deletions(-) diff --git a/shimmer/__init__.py b/shimmer/__init__.py index 5386208..7b87148 100644 --- a/shimmer/__init__.py +++ b/shimmer/__init__.py @@ -40,6 +40,7 @@ combine_loss, ) from shimmer.modules.selection import ( + ContentQ0SharedKeysSingleStep, RandomSelection, SelectionBase, SingleDomainSelection, @@ -103,6 +104,7 @@ "RandomSelection", "SelectionBase", "SingleDomainSelection", + "ContentQ0SharedKeysSingleStep", "DomainDesc", "RepeatedDataset", "ShimmerDataset", diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index e2b6b9f..a58d625 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -29,6 +29,7 @@ LossCoefs, ) from shimmer.modules.selection import ( + ContentQ0SharedKeysSingleStep, RandomSelection, SelectionBase, SingleDomainSelection, @@ -706,7 +707,9 @@ def __init__( ) -class GlobalWorkspaceFusion(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]): +class GlobalWorkspaceFusion( + GlobalWorkspaceBase[GWModule, SelectionBase, GWLosses] +): """The fusion (with broadcast loss) flavor of GlobalWorkspaceBase. This is used to simplify a Global Workspace instanciation and only overrides the @@ -721,6 +724,7 @@ def __init__( workspace_dim: int, loss_coefs: BroadcastLossCoefs | Mapping[str, float], selection_temperature: float = 0.2, + selection_mod: SelectionBase | None = None, optim_lr: float = 1e-3, optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, @@ -747,8 +751,11 @@ def __init__( workspace_dim (`int`): dimension of the GW. loss_coefs (`BroadcastLossCoefs | Mapping[str, float]`): loss coefs for the losses. - selection_temperature (`float`): temperature value for the RandomSelection - module. + selection_temperature (`float`): legacy temperature argument kept for + compatibility; ignored unless a custom `selection_mod` uses it. + selection_mod (`SelectionBase | None`): optional custom selection module. + If None (default), uses `ContentQ0SharedKeysSingleStep` with default + toggles. optim_lr (`float`): learning rate optim_weight_decay (`float`): weight decay scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments @@ -772,7 +779,10 @@ def __init__( torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale ) - selection_mod = RandomSelection(selection_temperature) + if selection_mod is None: + selection_mod = ContentQ0SharedKeysSingleStep( + gw_dim=workspace_dim, domain_names=domain_mods.keys() + ) loss_mod = GWLosses( gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss ) diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index ac03bdd..c5d1ff7 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Iterable +from typing import Dict import torch import torch.nn as nn @@ -149,7 +150,115 @@ def _calculate_attention_dict( attention_dict = { domain: attention_scores[:, i] for i, domain in enumerate(domains) } - return attention_dict + return attention_dict + + +class ContentQ0SharedKeysSingleStep(SelectionBase): + """ + Content-based single-step attention over GW latents with configurable toggles. + + Design: + - Query is the mean of available GW latents (content-q0 seed) + - Single-step dot-product attention over domains (no refinement loop) + - Optional per-domain keys + + Toggles: + - per_domain_keys: use per-domain key projections instead of a shared one + - stopgrad: detach GW latents before computing keys/query + """ + + def __init__( + self, + gw_dim: int, + domain_names: Iterable[str], + head_size: int = 64, + per_domain_keys: bool = False, + stopgrad: bool = True, + ): + super().__init__() + self.gw_dim = int(gw_dim) + self.head_size = int(head_size) + self.domain_names = list(domain_names) + + # Toggles + self.per_domain_keys = bool(per_domain_keys) + self.stopgrad = bool(stopgrad) + + # Projections + self.query_layer = nn.Linear(self.gw_dim, self.head_size) + self.shared_key_layer = nn.Linear(self.gw_dim, self.head_size) + self.per_key_layers = nn.ModuleDict( + {d: nn.Linear(self.gw_dim, self.head_size) for d in self.domain_names} + ) + + @staticmethod + def _calc_attention( + keys: Dict[str, torch.Tensor], + query: torch.Tensor, + order: Iterable[str], + ) -> dict[str, torch.Tensor]: + """ + Compute attention over domains. + + Args: + keys: mapping of domain -> key tensor (B, H) + query: query tensor (B, H) + order: iterable of domain names to fix output ordering + + Returns: + dict[str, torch.Tensor]: per-domain attention scores that sum to 1. + """ + names = [d for d in order if d in keys] + if not names: + raise ValueError("ContentQ0SharedKeysSingleStep: no keys provided.") + + logits = torch.stack( + [(keys[d] * query).sum(dim=1) for d in names], dim=1 + ) # (B, D) + + probs = torch.softmax(logits, dim=1) + + return {d: probs[:, i] for i, d in enumerate(names)} + + def forward(self, gw_latents: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Args: + gw_latents: mapping from domain name to GW latent (B, gw_dim) + + Returns: + dict[str, torch.Tensor]: per-domain attention weights. + """ + present = [d for d in self.domain_names if d in gw_latents] + if not present: + raise ValueError( + "ContentQ0SharedKeysSingleStep: no known domains present in gw_latents." + ) + + if self.stopgrad: + gw_latents = {d: t.detach() for d, t in gw_latents.items() if d in present} + else: + gw_latents = {d: gw_latents[d] for d in present} + + if self.per_domain_keys: + keys = {d: self.per_key_layers[d](gw_latents[d]) for d in present} + else: + proj = self.shared_key_layer + keys = {d: proj(gw_latents[d]) for d in present} + + stacked = torch.stack([gw_latents[d] for d in present], dim=0) # (D, B, F) + query = self.query_layer(stacked.mean(0)) # (B, H) + + return self._calc_attention( + keys=keys, + query=query, + order=self.domain_names, + ) + + def __call__( + self, encodings: LatentsDomainGroupT, gw_latents: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + # The first argument is ignored for compatibility with SelectionBase signature. + return self.forward(gw_latents) class RandomSelection(SelectionBase): From 2039528f8f30e74b306e191f38b88c9d823f8f7d Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Tue, 25 Nov 2025 11:56:10 +0000 Subject: [PATCH 02/33] Add helper to attach learned attention, keep fusion default random --- shimmer/modules/global_workspace.py | 35 +++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index a58d625..4196abd 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -290,6 +290,30 @@ def workspace_dim(self) -> int: """Dimension of the GW.""" return self.gw_mod.workspace_dim + def init_learned_attention( + self, + head_size: int = 64, + per_domain_keys: bool = False, + stopgrad: bool = True, + ) -> ContentQ0SharedKeysSingleStep: + """ + Initialize and attach a learned content-based attention module. + + This replaces `self.selection_mod` with a + `ContentQ0SharedKeysSingleStep` configured for the current workspace + (uses `workspace_dim` and domain names from `domain_mods`), ensuring its + parameters are tracked by Lightning/torch. + """ + selection = ContentQ0SharedKeysSingleStep( + gw_dim=self.workspace_dim, + domain_names=self.domain_mods.keys(), + head_size=head_size, + per_domain_keys=per_domain_keys, + stopgrad=stopgrad, + ) + self.selection_mod = selection + return selection + def encode_and_fuse( self, x: LatentsDomainGroupsT, selection_module: SelectionBase ) -> dict[frozenset[str], torch.Tensor]: @@ -751,11 +775,10 @@ def __init__( workspace_dim (`int`): dimension of the GW. loss_coefs (`BroadcastLossCoefs | Mapping[str, float]`): loss coefs for the losses. - selection_temperature (`float`): legacy temperature argument kept for - compatibility; ignored unless a custom `selection_mod` uses it. + selection_temperature (`float`): temperature value for the RandomSelection + module (default selection). selection_mod (`SelectionBase | None`): optional custom selection module. - If None (default), uses `ContentQ0SharedKeysSingleStep` with default - toggles. + If None (default), uses `RandomSelection`. optim_lr (`float`): learning rate optim_weight_decay (`float`): weight decay scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments @@ -780,9 +803,7 @@ def __init__( ) if selection_mod is None: - selection_mod = ContentQ0SharedKeysSingleStep( - gw_dim=workspace_dim, domain_names=domain_mods.keys() - ) + selection_mod = RandomSelection(selection_temperature) loss_mod = GWLosses( gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss ) From 6930d1d2b61670cefcf53e950744b66118da98dd Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Tue, 25 Nov 2025 12:03:32 +0000 Subject: [PATCH 03/33] Move learned attention helper to fusion class --- shimmer/modules/global_workspace.py | 48 ++++++++++++++--------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 4196abd..ac6485a 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -290,30 +290,6 @@ def workspace_dim(self) -> int: """Dimension of the GW.""" return self.gw_mod.workspace_dim - def init_learned_attention( - self, - head_size: int = 64, - per_domain_keys: bool = False, - stopgrad: bool = True, - ) -> ContentQ0SharedKeysSingleStep: - """ - Initialize and attach a learned content-based attention module. - - This replaces `self.selection_mod` with a - `ContentQ0SharedKeysSingleStep` configured for the current workspace - (uses `workspace_dim` and domain names from `domain_mods`), ensuring its - parameters are tracked by Lightning/torch. - """ - selection = ContentQ0SharedKeysSingleStep( - gw_dim=self.workspace_dim, - domain_names=self.domain_mods.keys(), - head_size=head_size, - per_domain_keys=per_domain_keys, - stopgrad=stopgrad, - ) - self.selection_mod = selection - return selection - def encode_and_fuse( self, x: LatentsDomainGroupsT, selection_module: SelectionBase ) -> dict[frozenset[str], torch.Tensor]: @@ -730,6 +706,30 @@ def __init__( scheduler, ) + def init_learned_attention( + self, + head_size: int = 64, + per_domain_keys: bool = False, + stopgrad: bool = True, + ) -> ContentQ0SharedKeysSingleStep: + """ + Initialize and attach a learned content-based attention module. + + This replaces `self.selection_mod` with a + `ContentQ0SharedKeysSingleStep` configured for the current workspace + (uses `workspace_dim` and domain names from `domain_mods`), ensuring its + parameters are tracked by Lightning/torch. + """ + selection = ContentQ0SharedKeysSingleStep( + gw_dim=self.workspace_dim, + domain_names=self.domain_mods.keys(), + head_size=head_size, + per_domain_keys=per_domain_keys, + stopgrad=stopgrad, + ) + self.selection_mod = selection + return selection + class GlobalWorkspaceFusion( GlobalWorkspaceBase[GWModule, SelectionBase, GWLosses] From a8004b125c58533189e10d467f208f7aaf205b93 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Tue, 25 Nov 2025 12:17:47 +0000 Subject: [PATCH 04/33] Rebalance broadcast loss coefs and aggregates --- shimmer/modules/losses.py | 26 +++++++++----------------- tests/test_broadcast.py | 8 ++------ 2 files changed, 11 insertions(+), 23 deletions(-) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index d934307..997422f 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -317,11 +317,8 @@ class BroadcastLossCoefs(TypedDict, total=False): contrastives: float """Contrastive loss coefficient.""" - fused: float - """fused loss coefficient (encode multiple domains and decode to one of them).""" - demi_cycles: float - """demi_cycles loss coefficient. Demi-cycles are always one-to-one""" + """demi_cycles loss coefficient. Demi-cycles aggregate fused cases too.""" cycles: float """cycles loss coefficient. Cycles can be many-to-one""" @@ -524,19 +521,19 @@ def broadcast_loss( raw_data: RawDomainGroupsT, ) -> dict[str, torch.Tensor]: """ - Computes broadcast loss including demi-cycle, cycle, and translation losses. + Computes broadcast loss including demi-cycle (with fused), cycle, and translation + losses. This return multiple metrics: * `demi_cycles` * `cycles` * `translations` - * `fused` * `from_{start_group}_to_{domain}_loss` where `{start_group}` is of the form "{domain1,domain2,domainN}" sorted in alphabetical order - (e.g. "from_{t,v}_to_t_loss"). - * `from_{start_group}_to_{domain}_{metric}` with - additional metrics provided by the domain_mod's - `compute_broadcast_loss` output + (e.g. "from_{t,v}_to_t_loss"). Note: fused cases are aggregated into + `demi_cycles`. + * `from_{start_group}_to_{domain}_{metric}` with additional metrics provided by + the domain_mod's `compute_broadcast_loss` output * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_loss` where `{start_group}`, `{target_group}` and `{case_group}` is of the form "{domain1,domain2,domainN}" sorted in alphabetical order @@ -563,7 +560,6 @@ def broadcast_loss( demi_cycle_losses: list[str] = [] cycle_losses: list[str] = [] translation_losses: list[str] = [] - fused_losses: list[str] = [] for group_domains, latents in latent_domains.items(): encoded_latents = gw_mod.encode(latents) @@ -616,8 +612,8 @@ def broadcast_loss( demi_cycle_losses.append(loss_label + "_loss") elif domain not in selected_latents: translation_losses.append(loss_label + "_loss") - else: # fused loss - fused_losses.append(loss_label + "_loss") + else: # fused loss counts toward demi_cycles aggregate + demi_cycle_losses.append(loss_label + "_loss") if num_active_domains < num_total_domains: inverse_selected_latents = { @@ -674,10 +670,6 @@ def broadcast_loss( metrics["translations"] = torch.mean( torch.stack([losses[loss_name] for loss_name in translation_losses]) ) - if fused_losses: - metrics["fused"] = torch.mean( - torch.stack([losses[loss_name] for loss_name in fused_losses]) - ) metrics.update(losses) return metrics diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index 012e169..fa6bff9 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -36,7 +36,6 @@ def test_broadcast_loss(): gw_decoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)} workspace_dim = 10 loss_coefs: BroadcastLossCoefs = { - "fused": 1.0, "cycles": 1.0, "demi_cycles": 1.0, "translations": 1.0, @@ -68,11 +67,8 @@ def test_broadcast_loss(): # Test broadcast_loss with the corrected structure output = gw_fusion.loss_mod.broadcast_loss(latent_domains, latent_domains) - er_msg = "Demi-cycle, cycle, fused and translation metrics should be in the output." - assert all( - metric in output - for metric in ["demi_cycles", "cycles", "translations", "fused"] - ), er_msg + er_msg = "Demi-cycle, cycle and translation metrics should be in the output." + assert all(metric in output for metric in ["demi_cycles", "cycles", "translations"]) er_msg = "Losses should be scalar tensors or 1D tensor with size equal to one." assert all( From e6461586cf4ea673489b416b14b24f3ee975a38e Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Tue, 25 Nov 2025 12:45:12 +0000 Subject: [PATCH 05/33] Fix indentation in selection helper causing test import failure --- shimmer/modules/selection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index c5d1ff7..e1cabdd 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -150,7 +150,7 @@ def _calculate_attention_dict( attention_dict = { domain: attention_scores[:, i] for i, domain in enumerate(domains) } - return attention_dict + return attention_dict class ContentQ0SharedKeysSingleStep(SelectionBase): From 4f09afc746757fd7e479884fd3d5506e21737125 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Tue, 25 Nov 2025 13:00:29 +0000 Subject: [PATCH 06/33] Run ruff cleanups in selection --- shimmer/modules/selection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index e1cabdd..97112ff 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from collections.abc import Iterable -from typing import Dict import torch import torch.nn as nn @@ -193,7 +192,7 @@ def __init__( @staticmethod def _calc_attention( - keys: Dict[str, torch.Tensor], + keys: dict[str, torch.Tensor], query: torch.Tensor, order: Iterable[str], ) -> dict[str, torch.Tensor]: From b31049aa0d73a5f502cc0a5df33224dfccea8b24 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Tue, 25 Nov 2025 13:07:02 +0000 Subject: [PATCH 07/33] Format global_workspace with ruff --- shimmer/modules/global_workspace.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index ac6485a..879f486 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -731,9 +731,7 @@ def init_learned_attention( return selection -class GlobalWorkspaceFusion( - GlobalWorkspaceBase[GWModule, SelectionBase, GWLosses] -): +class GlobalWorkspaceFusion(GlobalWorkspaceBase[GWModule, SelectionBase, GWLosses]): """The fusion (with broadcast loss) flavor of GlobalWorkspaceBase. This is used to simplify a Global Workspace instanciation and only overrides the From 95d054fbb7bc0a61c60b6c7c8842856a6f8a4d96 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Tue, 25 Nov 2025 13:18:57 +0000 Subject: [PATCH 08/33] Align selection signatures and move learned attention helper --- shimmer/modules/global_workspace.py | 49 ++++++++++++++--------------- shimmer/modules/selection.py | 21 +++++++------ 2 files changed, 36 insertions(+), 34 deletions(-) diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 879f486..c69d742 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -706,31 +706,6 @@ def __init__( scheduler, ) - def init_learned_attention( - self, - head_size: int = 64, - per_domain_keys: bool = False, - stopgrad: bool = True, - ) -> ContentQ0SharedKeysSingleStep: - """ - Initialize and attach a learned content-based attention module. - - This replaces `self.selection_mod` with a - `ContentQ0SharedKeysSingleStep` configured for the current workspace - (uses `workspace_dim` and domain names from `domain_mods`), ensuring its - parameters are tracked by Lightning/torch. - """ - selection = ContentQ0SharedKeysSingleStep( - gw_dim=self.workspace_dim, - domain_names=self.domain_mods.keys(), - head_size=head_size, - per_domain_keys=per_domain_keys, - stopgrad=stopgrad, - ) - self.selection_mod = selection - return selection - - class GlobalWorkspaceFusion(GlobalWorkspaceBase[GWModule, SelectionBase, GWLosses]): """The fusion (with broadcast loss) flavor of GlobalWorkspaceBase. @@ -816,6 +791,30 @@ def __init__( scheduler, ) + def init_learned_attention( + self, + head_size: int = 64, + per_domain_keys: bool = False, + stopgrad: bool = True, + ) -> ContentQ0SharedKeysSingleStep: + """ + Initialize and attach a learned content-based attention module. + + This replaces `self.selection_mod` with a + `ContentQ0SharedKeysSingleStep` configured for the current workspace + (uses `workspace_dim` and domain names from `domain_mods`), ensuring its + parameters are tracked by Lightning/torch. + """ + selection = ContentQ0SharedKeysSingleStep( + gw_dim=self.workspace_dim, + domain_names=self.domain_mods.keys(), + head_size=head_size, + per_domain_keys=per_domain_keys, + stopgrad=stopgrad, + ) + self.selection_mod = selection + return selection + def pretrained_global_workspace( checkpoint_path: str | Path, diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index 97112ff..f8d895a 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from collections.abc import Iterable +from collections.abc import Iterable, Mapping import torch import torch.nn as nn @@ -219,14 +219,23 @@ def _calc_attention( return {d: probs[:, i] for i, d in enumerate(names)} - def forward(self, gw_latents: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + def forward( + self, + domains: LatentsDomainGroupT, + encodings_pre_fusion: LatentsDomainGroupT | None = None, + ) -> dict[str, torch.Tensor]: """ Args: - gw_latents: mapping from domain name to GW latent (B, gw_dim) + domains: mapping from domain name to GW latent (B, gw_dim) + encodings_pre_fusion: unused; kept for `SelectionBase` compatibility. Returns: dict[str, torch.Tensor]: per-domain attention weights. """ + del encodings_pre_fusion # unused + + gw_latents: Mapping[str, torch.Tensor] = domains + present = [d for d in self.domain_names if d in gw_latents] if not present: raise ValueError( @@ -253,12 +262,6 @@ def forward(self, gw_latents: dict[str, torch.Tensor]) -> dict[str, torch.Tensor order=self.domain_names, ) - def __call__( - self, encodings: LatentsDomainGroupT, gw_latents: dict[str, torch.Tensor] - ) -> dict[str, torch.Tensor]: - # The first argument is ignored for compatibility with SelectionBase signature. - return self.forward(gw_latents) - class RandomSelection(SelectionBase): """ From 7bca399590290cb673f3847d62fbe931f9f171a2 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Tue, 25 Nov 2025 13:22:35 +0000 Subject: [PATCH 09/33] Apply ruff formatting to global workspace --- shimmer/modules/global_workspace.py | 1 + 1 file changed, 1 insertion(+) diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index c69d742..46a1430 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -706,6 +706,7 @@ def __init__( scheduler, ) + class GlobalWorkspaceFusion(GlobalWorkspaceBase[GWModule, SelectionBase, GWLosses]): """The fusion (with broadcast loss) flavor of GlobalWorkspaceBase. From 95b3b628a773f17e57c59b1ced347bde682d7432 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Fri, 28 Nov 2025 10:20:13 +0000 Subject: [PATCH 10/33] Update broadcasts docstrings --- shimmer/modules/global_workspace.py | 2 +- shimmer/modules/gw_module.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 46a1430..1178127 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -66,7 +66,7 @@ class GWPredictionsBase(TypedDict): broadcasts: dict[frozenset[str], dict[str, torch.Tensor]] """ broadcasts predictions of the model for each domain. It contains demi-cycles, - translations, and fused. + translations. """ cycles: dict[frozenset[str], dict[str, torch.Tensor]] diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index a7c039b..c21a3a5 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -217,7 +217,7 @@ class GWModulePrediction(TypedDict): broadcasts: dict[str, torch.Tensor] """ broadcasts predictions of the model for each domain. It contains demi-cycles, - translations, and fused. + translations. """ cycles: dict[str, torch.Tensor] From cc9dbc368419d65845c3db51366c1250672f777b Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Fri, 28 Nov 2025 10:26:42 +0000 Subject: [PATCH 11/33] Remove fused loss handler --- shimmer/modules/domain.py | 19 ------------------- shimmer/modules/losses.py | 2 +- 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/shimmer/modules/domain.py b/shimmer/modules/domain.py index c909f01..4d23d14 100644 --- a/shimmer/modules/domain.py +++ b/shimmer/modules/domain.py @@ -185,25 +185,6 @@ def compute_tr_loss( """ return self.compute_loss(pred, target, raw_target) - def compute_fused_loss( - self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any - ) -> LossOutput | None: - """ - Computes the loss for fused (fusion). Override if the fused loss is - different that the generic loss. - - Args: - pred (`torch.Tensor`): prediction of the model - target (`torch.Tensor`): target tensor - raw_target (`Any`): raw data from the input - Results: - `LossOutput | None`: LossOuput with training loss and additional metrics. - If `None` is returned, this loss will be ignored and will not - participate in the total loss; it can be used to deactivate - fused loss for this domain. - """ - return self.compute_loss(pred, target, raw_target) - def compute_domain_loss(self, domain: Any) -> LossOutput | None: """ Compute the unimodal domain loss. diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 997422f..b546bf4 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -594,7 +594,7 @@ def broadcast_loss( elif domain not in selected_latents: loss_fn = domain_mods[domain].compute_tr_loss else: - loss_fn = domain_mods[domain].compute_fused_loss + loss_fn = domain_mods[domain].compute_dcy_loss loss_output = loss_fn( pred, ground_truth, raw_data[group_domains][domain] From 1b2fd83ce51ba93dbfb2330f5f0de28cb2fad6c7 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Mon, 8 Dec 2025 13:23:59 +0000 Subject: [PATCH 12/33] Rename learned attention module --- shimmer/__init__.py | 4 ++-- shimmer/modules/global_workspace.py | 13 ++++++------- shimmer/modules/selection.py | 6 +++--- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/shimmer/__init__.py b/shimmer/__init__.py index 7b87148..5e8355d 100644 --- a/shimmer/__init__.py +++ b/shimmer/__init__.py @@ -40,7 +40,7 @@ combine_loss, ) from shimmer.modules.selection import ( - ContentQ0SharedKeysSingleStep, + LearnedAttention, RandomSelection, SelectionBase, SingleDomainSelection, @@ -104,7 +104,7 @@ "RandomSelection", "SelectionBase", "SingleDomainSelection", - "ContentQ0SharedKeysSingleStep", + "LearnedAttention", "DomainDesc", "RepeatedDataset", "ShimmerDataset", diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 1178127..64a18ee 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -29,7 +29,7 @@ LossCoefs, ) from shimmer.modules.selection import ( - ContentQ0SharedKeysSingleStep, + LearnedAttention, RandomSelection, SelectionBase, SingleDomainSelection, @@ -797,16 +797,15 @@ def init_learned_attention( head_size: int = 64, per_domain_keys: bool = False, stopgrad: bool = True, - ) -> ContentQ0SharedKeysSingleStep: + ) -> LearnedAttention: """ Initialize and attach a learned content-based attention module. - This replaces `self.selection_mod` with a - `ContentQ0SharedKeysSingleStep` configured for the current workspace - (uses `workspace_dim` and domain names from `domain_mods`), ensuring its - parameters are tracked by Lightning/torch. + This replaces `self.selection_mod` with a `LearnedAttention` configured for + the current workspace (uses `workspace_dim` and domain names from + `domain_mods`), ensuring its parameters are tracked by Lightning/torch. """ - selection = ContentQ0SharedKeysSingleStep( + selection = LearnedAttention( gw_dim=self.workspace_dim, domain_names=self.domain_mods.keys(), head_size=head_size, diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index f8d895a..cc7531c 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -152,7 +152,7 @@ def _calculate_attention_dict( return attention_dict -class ContentQ0SharedKeysSingleStep(SelectionBase): +class LearnedAttention(SelectionBase): """ Content-based single-step attention over GW latents with configurable toggles. @@ -209,7 +209,7 @@ def _calc_attention( """ names = [d for d in order if d in keys] if not names: - raise ValueError("ContentQ0SharedKeysSingleStep: no keys provided.") + raise ValueError("LearnedAttention: no keys provided.") logits = torch.stack( [(keys[d] * query).sum(dim=1) for d in names], dim=1 @@ -239,7 +239,7 @@ def forward( present = [d for d in self.domain_names if d in gw_latents] if not present: raise ValueError( - "ContentQ0SharedKeysSingleStep: no known domains present in gw_latents." + "LearnedAttention: no known domains present in gw_latents." ) if self.stopgrad: From c16fd1d63629758a62bd0a8aa8b5b747c19d9909 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Mon, 8 Dec 2025 14:14:31 +0000 Subject: [PATCH 13/33] Add LearnedAttention coverage --- tests/test_learned_attention.py | 49 +++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 tests/test_learned_attention.py diff --git a/tests/test_learned_attention.py b/tests/test_learned_attention.py new file mode 100644 index 0000000..2b78ca6 --- /dev/null +++ b/tests/test_learned_attention.py @@ -0,0 +1,49 @@ +import torch + +from shimmer.modules.selection import LearnedAttention + + +def _make_latents(batch_size: int, dim: int) -> dict[str, torch.Tensor]: + return { + "a": torch.randn(batch_size, dim, requires_grad=True), + "b": torch.randn(batch_size, dim, requires_grad=True), + } + + +def test_learned_attention_probs_sum_to_one() -> None: + selector = LearnedAttention(gw_dim=4, domain_names=["a", "b"], head_size=2) + latents = _make_latents(batch_size=8, dim=4) + + weights = selector(latents, encodings_pre_fusion=None) + + for domain in ["a", "b"]: + assert weights[domain].shape == (8,) + + stacked = torch.stack([weights["a"], weights["b"]], dim=1) + assert torch.allclose(stacked.sum(dim=1), torch.ones(8)) + + +def test_learned_attention_stopgrad_toggle() -> None: + base_latents = _make_latents(batch_size=4, dim=6) + + frozen_latents = { + k: v.detach().clone().requires_grad_(True) for k, v in base_latents.items() + } + frozen_selector = LearnedAttention( + gw_dim=6, domain_names=["a", "b"], head_size=3, stopgrad=True + ) + frozen_weights = frozen_selector(frozen_latents, encodings_pre_fusion=None) + torch.stack(list(frozen_weights.values())).sum().backward() + assert frozen_latents["a"].grad is None + assert frozen_latents["b"].grad is None + + train_latents = { + k: v.detach().clone().requires_grad_(True) for k, v in base_latents.items() + } + trainable_selector = LearnedAttention( + gw_dim=6, domain_names=["a", "b"], head_size=3, stopgrad=False + ) + trainable_weights = trainable_selector(train_latents, encodings_pre_fusion=None) + torch.stack(list(trainable_weights.values())).sum().backward() + assert train_latents["a"].grad is not None + assert train_latents["b"].grad is not None From 4db5e55ce10f772c8afe5c898565d1c767e6ab3a Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Mon, 8 Dec 2025 14:21:50 +0000 Subject: [PATCH 14/33] Fix LearnedAttention test types for mypy --- tests/test_learned_attention.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_learned_attention.py b/tests/test_learned_attention.py index 2b78ca6..3144915 100644 --- a/tests/test_learned_attention.py +++ b/tests/test_learned_attention.py @@ -14,7 +14,7 @@ def test_learned_attention_probs_sum_to_one() -> None: selector = LearnedAttention(gw_dim=4, domain_names=["a", "b"], head_size=2) latents = _make_latents(batch_size=8, dim=4) - weights = selector(latents, encodings_pre_fusion=None) + weights = selector(latents, encodings_pre_fusion=latents) for domain in ["a", "b"]: assert weights[domain].shape == (8,) @@ -32,7 +32,9 @@ def test_learned_attention_stopgrad_toggle() -> None: frozen_selector = LearnedAttention( gw_dim=6, domain_names=["a", "b"], head_size=3, stopgrad=True ) - frozen_weights = frozen_selector(frozen_latents, encodings_pre_fusion=None) + frozen_weights = frozen_selector( + frozen_latents, encodings_pre_fusion=frozen_latents + ) torch.stack(list(frozen_weights.values())).sum().backward() assert frozen_latents["a"].grad is None assert frozen_latents["b"].grad is None @@ -43,7 +45,9 @@ def test_learned_attention_stopgrad_toggle() -> None: trainable_selector = LearnedAttention( gw_dim=6, domain_names=["a", "b"], head_size=3, stopgrad=False ) - trainable_weights = trainable_selector(train_latents, encodings_pre_fusion=None) + trainable_weights = trainable_selector( + train_latents, encodings_pre_fusion=train_latents + ) torch.stack(list(trainable_weights.values())).sum().backward() assert train_latents["a"].grad is not None assert train_latents["b"].grad is not None From bc5323046b34b9ff7e8bffcf68b3dd04ae2f3432 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Wed, 10 Dec 2025 11:59:25 +0000 Subject: [PATCH 15/33] Split cycle loss from broadcast path --- shimmer/data/dataset.py | 6 +- shimmer/modules/losses.py | 103 +++++++++++++++++++------- tests/test_broadcast.py | 15 +++- tests/test_kq_onepass_attention.py | 12 +-- tests/test_query_attention.py | 12 +-- tests/test_random_attention.py | 18 ++--- tests/test_single_domain_selection.py | 12 +-- 7 files changed, 119 insertions(+), 59 deletions(-) diff --git a/shimmer/data/dataset.py b/shimmer/data/dataset.py index 8f471b8..2bd587b 100644 --- a/shimmer/data/dataset.py +++ b/shimmer/data/dataset.py @@ -102,9 +102,9 @@ def __init__( ) self.dataset_size = min_length if self.max_size is not None: - assert ( - self.max_size <= self.dataset_size - ), "Max sizes can only be lower than actual size." + assert self.max_size <= self.dataset_size, ( + "Max sizes can only be lower than actual size." + ) self.dataset_size = self.max_size def __len__(self) -> int: diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index b546bf4..14aedc7 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -357,6 +357,28 @@ def combine_loss( return loss +class CycleCase(TypedDict): + """Container for precomputed cycle data to avoid recomputation.""" + + loss_label: str + domain_name: str + prediction: torch.Tensor + target: torch.Tensor + raw_target: object + + +class BroadcastLossResult(TypedDict): + """ + Broadcast loss output without cycle computation. + + `metrics` contains demi-cycle/translation metrics and per-example losses. + `cycle_cases` holds precomputed tensors/labels for later cycle loss computation. + """ + + metrics: dict[str, torch.Tensor] + cycle_cases: list[CycleCase] + + class GWLosses2Domains(GWLossesBase): """ Implementation of `GWLossesBase` used for `GWModule`. @@ -519,10 +541,10 @@ def broadcast_loss( domain_mods: Mapping[str, DomainModule], latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT, -) -> dict[str, torch.Tensor]: +) -> BroadcastLossResult: """ - Computes broadcast loss including demi-cycle (with fused), cycle, and translation - losses. + Computes broadcast demi-cycle (with fused) and translation losses, and prepares + precomputed artifacts for cycle losses. This return multiple metrics: * `demi_cycles` @@ -552,14 +574,14 @@ def broadcast_loss( raw_data (`RawDomainGroupsT`): raw input data Returns: - A dictionary with the total loss and additional metrics. + `BroadcastLossResult`: demi/translation metrics plus precomputed cycle data. """ # noqa: E501 losses: dict[str, torch.Tensor] = {} metrics: dict[str, torch.Tensor] = {} demi_cycle_losses: list[str] = [] - cycle_losses: list[str] = [] translation_losses: list[str] = [] + cycle_cases: list[CycleCase] = [] for group_domains, latents in latent_domains.items(): encoded_latents = gw_mod.encode(latents) @@ -636,42 +658,67 @@ def broadcast_loss( ) for domain in selected_latents: - re_ground_truth = latents[domain] - re_loss_output = domain_mods[domain].compute_cy_loss( - re_decoded_latents[domain], - re_ground_truth, - raw_data[group_domains][domain], - ) - if re_loss_output is None: - continue loss_label = ( f"from_{selected_group_label}_" f"through_{inverse_selected_group_label}_to_{domain}_" f"case_{group_name}" ) - losses[loss_label + "_loss"] = re_loss_output.loss - metrics.update( - { - f"{loss_label}_{k}": v - for k, v in re_loss_output.metrics.items() - } + cycle_cases.append( + CycleCase( + loss_label=loss_label, + domain_name=domain, + prediction=re_decoded_latents[domain], + target=latents[domain], + raw_target=raw_data[group_domains][domain], + ) ) - cycle_losses.append(loss_label + "_loss") if demi_cycle_losses: metrics["demi_cycles"] = torch.mean( torch.stack([losses[loss_name] for loss_name in demi_cycle_losses]) ) - if cycle_losses: - metrics["cycles"] = torch.mean( - torch.stack([losses[loss_name] for loss_name in cycle_losses]) - ) if translation_losses: metrics["translations"] = torch.mean( torch.stack([losses[loss_name] for loss_name in translation_losses]) ) metrics.update(losses) + return BroadcastLossResult(metrics=metrics, cycle_cases=cycle_cases) + + +def cycle_loss_from_broadcast( + domain_mods: Mapping[str, DomainModule], + cycle_cases: list[CycleCase], +) -> dict[str, torch.Tensor]: + """ + Computes cycle losses from precomputed broadcast artifacts. + + Args: + domain_mods: domain modules used to compute the losses. + cycle_cases: precomputed cycle data produced by `broadcast_loss`. + + Returns: + Metrics dict containing per-case losses/metrics and aggregate `cycles`. + """ + metrics: dict[str, torch.Tensor] = {} + cycle_losses: list[str] = [] + + for case in cycle_cases: + loss_output = domain_mods[case["domain_name"]].compute_cy_loss( + case["prediction"], case["target"], case["raw_target"] + ) + if loss_output is None: + continue + loss_name = case["loss_label"] + metrics[loss_name + "_loss"] = loss_output.loss + metrics.update({f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()}) + cycle_losses.append(loss_name + "_loss") + + if cycle_losses: + metrics["cycles"] = torch.mean( + torch.stack([metrics[loss_name] for loss_name in cycle_losses]) + ) + return metrics @@ -722,7 +769,7 @@ def contrastive_loss( def broadcast_loss( self, latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT - ) -> dict[str, torch.Tensor]: + ) -> BroadcastLossResult: return broadcast_loss( self.gw_mod, self.selection_mod, self.domain_mods, latent_domains, raw_data ) @@ -748,7 +795,11 @@ def step( metrics: dict[str, torch.Tensor] = {} metrics.update(self.contrastive_loss(domain_latents)) - metrics.update(self.broadcast_loss(domain_latents, raw_data)) + broadcast_result = self.broadcast_loss(domain_latents, raw_data) + metrics.update(broadcast_result["metrics"]) + metrics.update( + cycle_loss_from_broadcast(self.domain_mods, broadcast_result["cycle_cases"]) + ) loss = combine_loss(metrics, self.loss_coefs) diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index fa6bff9..ddbba22 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -65,14 +65,23 @@ def test_broadcast_loss(): } # Test broadcast_loss with the corrected structure - output = gw_fusion.loss_mod.broadcast_loss(latent_domains, latent_domains) + result = gw_fusion.loss_mod.broadcast_loss(latent_domains, latent_domains) + assert "metrics" in result and "cycle_cases" in result + # Cycle metrics are computed in step(), not within broadcast_loss + metrics = result["metrics"] + assert all(metric in metrics for metric in ["demi_cycles", "translations"]) + # Cycle metrics should appear after the step call + step_output = gw_fusion.loss_mod.step(latent_domains, latent_domains, mode="train") er_msg = "Demi-cycle, cycle and translation metrics should be in the output." - assert all(metric in output for metric in ["demi_cycles", "cycles", "translations"]) + assert all( + metric in step_output.metrics + for metric in ["demi_cycles", "cycles", "translations"] + ) er_msg = "Losses should be scalar tensors or 1D tensor with size equal to one." assert all( (loss.dim() == 0 or (loss.dim() == 1 and loss.size(0) == 1)) - for key, loss in output.items() + for key, loss in metrics.items() if key.endswith("_loss") ), er_msg diff --git a/tests/test_kq_onepass_attention.py b/tests/test_kq_onepass_attention.py index 74b8b8e..84f0f19 100644 --- a/tests/test_kq_onepass_attention.py +++ b/tests/test_kq_onepass_attention.py @@ -18,9 +18,9 @@ def test_single_domain(): attention_scores = attention(single_domain_input, encodings_pre_fusion) expected_scores = torch.ones(batch_size, 1) - assert torch.allclose( - attention_scores["v_latents"], expected_scores - ), "Attention scores for single domain should be all 1s" + assert torch.allclose(attention_scores["v_latents"], expected_scores), ( + "Attention scores for single domain should be all 1s" + ) def test_multiple_domains_sumis1(): @@ -50,6 +50,6 @@ def test_multiple_domains_sumis1(): expected_sum = torch.ones(batch_size) - assert torch.allclose( - scores_sum, expected_sum - ), "Sum of attention scores across domains should be 1" + assert torch.allclose(scores_sum, expected_sum), ( + "Sum of attention scores across domains should be 1" + ) diff --git a/tests/test_query_attention.py b/tests/test_query_attention.py index 61d99bd..3161760 100644 --- a/tests/test_query_attention.py +++ b/tests/test_query_attention.py @@ -59,9 +59,9 @@ def test_single_domain(): attention_scores = attention(single_domain_input, prefusion_encodings) expected_scores = torch.ones(batch_size, 1) - assert torch.allclose( - attention_scores["v_latents"], expected_scores - ), "Attention scores for single domain should be all 1s" + assert torch.allclose(attention_scores["v_latents"], expected_scores), ( + "Attention scores for single domain should be all 1s" + ) def test_multiple_domains_sumis1(): @@ -89,9 +89,9 @@ def test_multiple_domains_sumis1(): expected_sum = torch.ones(batch_size) - assert torch.allclose( - scores_sum, expected_sum - ), "Sum of attention scores across domains should be 1" + assert torch.allclose(scores_sum, expected_sum), ( + "Sum of attention scores across domains should be 1" + ) def test_attention_backward(): diff --git a/tests/test_random_attention.py b/tests/test_random_attention.py index 5422fb7..8ae1e34 100644 --- a/tests/test_random_attention.py +++ b/tests/test_random_attention.py @@ -29,9 +29,9 @@ def test_multiple_domains(): expected_sum = torch.ones(batch_size) - assert torch.allclose( - scores_sum, expected_sum - ), "Sum of selection scores across domains should be 1" + assert torch.allclose(scores_sum, expected_sum), ( + "Sum of selection scores across domains should be 1" + ) def test_three_domains(): @@ -56,9 +56,9 @@ def test_three_domains(): # Ensure that the shape of the selection scores matches the input domains for domain in three_domain_input: - assert selection_scores[domain].shape == ( - batch_size, - ), f"Scores shape mismatch for {domain}" + assert selection_scores[domain].shape == (batch_size,), ( + f"Scores shape mismatch for {domain}" + ) # Ensure the sum of attention scores across domains equals 1 scores_sum = sum( @@ -68,6 +68,6 @@ def test_three_domains(): expected_sum = torch.ones(batch_size) - assert torch.allclose( - scores_sum, expected_sum - ), "Sum of selection scores across three domains should be 1" + assert torch.allclose(scores_sum, expected_sum), ( + "Sum of selection scores across three domains should be 1" + ) diff --git a/tests/test_single_domain_selection.py b/tests/test_single_domain_selection.py index 90b44ab..5bc21ff 100644 --- a/tests/test_single_domain_selection.py +++ b/tests/test_single_domain_selection.py @@ -27,9 +27,9 @@ def test_selection_2_domains(): selection: dict[str, torch.Tensor] = selection_mod(domains, prefusion_encodings) assert len(selection) == len(domains) - assert ( - (selection["v"] + selection["t"]) == 1 - ).sum() == bs, "Everything should be selected once and only once." + assert ((selection["v"] + selection["t"]) == 1).sum() == bs, ( + "Everything should be selected once and only once." + ) def test_selection_3_domains(): @@ -50,6 +50,6 @@ def test_selection_3_domains(): selection: dict[str, torch.Tensor] = selection_mod(domains, prefusion_encodings) assert len(selection) == len(domains) - assert ( - (selection["v"] + selection["t"] + selection["attr"]) == 1 - ).sum() == bs, "Everything should be selected once and only once." + assert ((selection["v"] + selection["t"] + selection["attr"]) == 1).sum() == bs, ( + "Everything should be selected once and only once." + ) From 931df0419feeacd62d9b9d00a7ed0ea130ed1b2a Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Wed, 10 Dec 2025 12:40:07 +0000 Subject: [PATCH 16/33] Format code with ruff --- .gitignore | 1 + shimmer/data/dataset.py | 6 +++--- tests/test_kq_onepass_attention.py | 12 ++++++------ tests/test_query_attention.py | 12 ++++++------ tests/test_random_attention.py | 18 +++++++++--------- tests/test_single_domain_selection.py | 12 ++++++------ 6 files changed, 31 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index cf96243..9cf8580 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,4 @@ cython_debug/ .rgignore .ruff_cache/ +.poetry_cache/ diff --git a/shimmer/data/dataset.py b/shimmer/data/dataset.py index 2bd587b..8f471b8 100644 --- a/shimmer/data/dataset.py +++ b/shimmer/data/dataset.py @@ -102,9 +102,9 @@ def __init__( ) self.dataset_size = min_length if self.max_size is not None: - assert self.max_size <= self.dataset_size, ( - "Max sizes can only be lower than actual size." - ) + assert ( + self.max_size <= self.dataset_size + ), "Max sizes can only be lower than actual size." self.dataset_size = self.max_size def __len__(self) -> int: diff --git a/tests/test_kq_onepass_attention.py b/tests/test_kq_onepass_attention.py index 84f0f19..74b8b8e 100644 --- a/tests/test_kq_onepass_attention.py +++ b/tests/test_kq_onepass_attention.py @@ -18,9 +18,9 @@ def test_single_domain(): attention_scores = attention(single_domain_input, encodings_pre_fusion) expected_scores = torch.ones(batch_size, 1) - assert torch.allclose(attention_scores["v_latents"], expected_scores), ( - "Attention scores for single domain should be all 1s" - ) + assert torch.allclose( + attention_scores["v_latents"], expected_scores + ), "Attention scores for single domain should be all 1s" def test_multiple_domains_sumis1(): @@ -50,6 +50,6 @@ def test_multiple_domains_sumis1(): expected_sum = torch.ones(batch_size) - assert torch.allclose(scores_sum, expected_sum), ( - "Sum of attention scores across domains should be 1" - ) + assert torch.allclose( + scores_sum, expected_sum + ), "Sum of attention scores across domains should be 1" diff --git a/tests/test_query_attention.py b/tests/test_query_attention.py index 3161760..61d99bd 100644 --- a/tests/test_query_attention.py +++ b/tests/test_query_attention.py @@ -59,9 +59,9 @@ def test_single_domain(): attention_scores = attention(single_domain_input, prefusion_encodings) expected_scores = torch.ones(batch_size, 1) - assert torch.allclose(attention_scores["v_latents"], expected_scores), ( - "Attention scores for single domain should be all 1s" - ) + assert torch.allclose( + attention_scores["v_latents"], expected_scores + ), "Attention scores for single domain should be all 1s" def test_multiple_domains_sumis1(): @@ -89,9 +89,9 @@ def test_multiple_domains_sumis1(): expected_sum = torch.ones(batch_size) - assert torch.allclose(scores_sum, expected_sum), ( - "Sum of attention scores across domains should be 1" - ) + assert torch.allclose( + scores_sum, expected_sum + ), "Sum of attention scores across domains should be 1" def test_attention_backward(): diff --git a/tests/test_random_attention.py b/tests/test_random_attention.py index 8ae1e34..5422fb7 100644 --- a/tests/test_random_attention.py +++ b/tests/test_random_attention.py @@ -29,9 +29,9 @@ def test_multiple_domains(): expected_sum = torch.ones(batch_size) - assert torch.allclose(scores_sum, expected_sum), ( - "Sum of selection scores across domains should be 1" - ) + assert torch.allclose( + scores_sum, expected_sum + ), "Sum of selection scores across domains should be 1" def test_three_domains(): @@ -56,9 +56,9 @@ def test_three_domains(): # Ensure that the shape of the selection scores matches the input domains for domain in three_domain_input: - assert selection_scores[domain].shape == (batch_size,), ( - f"Scores shape mismatch for {domain}" - ) + assert selection_scores[domain].shape == ( + batch_size, + ), f"Scores shape mismatch for {domain}" # Ensure the sum of attention scores across domains equals 1 scores_sum = sum( @@ -68,6 +68,6 @@ def test_three_domains(): expected_sum = torch.ones(batch_size) - assert torch.allclose(scores_sum, expected_sum), ( - "Sum of selection scores across three domains should be 1" - ) + assert torch.allclose( + scores_sum, expected_sum + ), "Sum of selection scores across three domains should be 1" diff --git a/tests/test_single_domain_selection.py b/tests/test_single_domain_selection.py index 5bc21ff..90b44ab 100644 --- a/tests/test_single_domain_selection.py +++ b/tests/test_single_domain_selection.py @@ -27,9 +27,9 @@ def test_selection_2_domains(): selection: dict[str, torch.Tensor] = selection_mod(domains, prefusion_encodings) assert len(selection) == len(domains) - assert ((selection["v"] + selection["t"]) == 1).sum() == bs, ( - "Everything should be selected once and only once." - ) + assert ( + (selection["v"] + selection["t"]) == 1 + ).sum() == bs, "Everything should be selected once and only once." def test_selection_3_domains(): @@ -50,6 +50,6 @@ def test_selection_3_domains(): selection: dict[str, torch.Tensor] = selection_mod(domains, prefusion_encodings) assert len(selection) == len(domains) - assert ((selection["v"] + selection["t"] + selection["attr"]) == 1).sum() == bs, ( - "Everything should be selected once and only once." - ) + assert ( + (selection["v"] + selection["t"] + selection["attr"]) == 1 + ).sum() == bs, "Everything should be selected once and only once." From 9f15e82ec13109c283d6d3aa28a3aa42a5d148b3 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Wed, 10 Dec 2025 15:22:28 +0000 Subject: [PATCH 17/33] Move cycle reconstruction into cycle loss --- shimmer/modules/losses.py | 108 +++++++++++++++++++++----------------- 1 file changed, 59 insertions(+), 49 deletions(-) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 14aedc7..b01cf65 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -358,13 +358,13 @@ def combine_loss( class CycleCase(TypedDict): - """Container for precomputed cycle data to avoid recomputation.""" + """Container for precomputed cycle inputs to avoid recomputation.""" - loss_label: str - domain_name: str - prediction: torch.Tensor - target: torch.Tensor - raw_target: object + group_name: str + selected_group_label: str + selected_latents: Mapping[str, torch.Tensor] + decoded_latents: Mapping[str, torch.Tensor] + raw_group: Mapping[str, object] class BroadcastLossResult(TypedDict): @@ -372,7 +372,7 @@ class BroadcastLossResult(TypedDict): Broadcast loss output without cycle computation. `metrics` contains demi-cycle/translation metrics and per-example losses. - `cycle_cases` holds precomputed tensors/labels for later cycle loss computation. + `cycle_cases` holds precomputed inputs for later cycle loss computation. """ metrics: dict[str, torch.Tensor] @@ -638,40 +638,15 @@ def broadcast_loss( demi_cycle_losses.append(loss_label + "_loss") if num_active_domains < num_total_domains: - inverse_selected_latents = { - domain: decoded_latents[domain] - for domain in decoded_latents - if domain not in selected_latents - } - - inverse_selected_group_label = ( - "{" + ",".join(sorted(inverse_selected_latents)) + "}" - ) - - re_encoded_latents = gw_mod.encode(inverse_selected_latents) - re_selection_scores = selection_mod( - inverse_selected_latents, re_encoded_latents - ) - re_fused_latents = gw_mod.fuse(re_encoded_latents, re_selection_scores) - re_decoded_latents = gw_mod.decode( - re_fused_latents, domains=selected_latents.keys() - ) - - for domain in selected_latents: - loss_label = ( - f"from_{selected_group_label}_" - f"through_{inverse_selected_group_label}_to_{domain}_" - f"case_{group_name}" - ) - cycle_cases.append( - CycleCase( - loss_label=loss_label, - domain_name=domain, - prediction=re_decoded_latents[domain], - target=latents[domain], - raw_target=raw_data[group_domains][domain], - ) + cycle_cases.append( + CycleCase( + group_name=group_name, + selected_group_label=selected_group_label, + selected_latents=selected_latents, + decoded_latents=decoded_latents, + raw_group=raw_data[group_domains], ) + ) if demi_cycle_losses: metrics["demi_cycles"] = torch.mean( @@ -687,6 +662,8 @@ def broadcast_loss( def cycle_loss_from_broadcast( + gw_mod: GWModuleBase, + selection_mod: SelectionBase, domain_mods: Mapping[str, DomainModule], cycle_cases: list[CycleCase], ) -> dict[str, torch.Tensor]: @@ -694,6 +671,8 @@ def cycle_loss_from_broadcast( Computes cycle losses from precomputed broadcast artifacts. Args: + gw_mod: GW module used for encoding/decoding. + selection_mod: selection module used during fusion. domain_mods: domain modules used to compute the losses. cycle_cases: precomputed cycle data produced by `broadcast_loss`. @@ -704,15 +683,41 @@ def cycle_loss_from_broadcast( cycle_losses: list[str] = [] for case in cycle_cases: - loss_output = domain_mods[case["domain_name"]].compute_cy_loss( - case["prediction"], case["target"], case["raw_target"] + inverse_selected_latents = { + domain: case["decoded_latents"][domain] + for domain in case["decoded_latents"] + if domain not in case["selected_latents"] + } + inverse_selected_group_label = ( + "{" + ",".join(sorted(inverse_selected_latents)) + "}" ) - if loss_output is None: - continue - loss_name = case["loss_label"] - metrics[loss_name + "_loss"] = loss_output.loss - metrics.update({f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()}) - cycle_losses.append(loss_name + "_loss") + + re_encoded_latents = gw_mod.encode(inverse_selected_latents) + re_selection_scores = selection_mod( + inverse_selected_latents, re_encoded_latents + ) + re_fused_latents = gw_mod.fuse(re_encoded_latents, re_selection_scores) + re_decoded_latents = gw_mod.decode( + re_fused_latents, domains=case["selected_latents"].keys() + ) + + for domain, target in case["selected_latents"].items(): + loss_name = ( + f"from_{case['selected_group_label']}_" + f"through_{inverse_selected_group_label}_to_{domain}_" + f"case_{case['group_name']}" + ) + loss_output = domain_mods[domain].compute_cy_loss( + re_decoded_latents[domain], target, case["raw_group"][domain] + ) + if loss_output is None: + continue + + metrics[loss_name + "_loss"] = loss_output.loss + metrics.update( + {f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()} + ) + cycle_losses.append(loss_name + "_loss") if cycle_losses: metrics["cycles"] = torch.mean( @@ -798,7 +803,12 @@ def step( broadcast_result = self.broadcast_loss(domain_latents, raw_data) metrics.update(broadcast_result["metrics"]) metrics.update( - cycle_loss_from_broadcast(self.domain_mods, broadcast_result["cycle_cases"]) + cycle_loss_from_broadcast( + self.gw_mod, + self.selection_mod, + self.domain_mods, + broadcast_result["cycle_cases"], + ) ) loss = combine_loss(metrics, self.loss_coefs) From 3dca8a4a968c608600f90424eb2d66ac636c4343 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Wed, 10 Dec 2025 16:43:12 +0000 Subject: [PATCH 18/33] Ensure LearnedAttention only builds selected key projection --- shimmer/modules/selection.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index cc7531c..ca2e15d 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -185,10 +185,14 @@ def __init__( # Projections self.query_layer = nn.Linear(self.gw_dim, self.head_size) - self.shared_key_layer = nn.Linear(self.gw_dim, self.head_size) - self.per_key_layers = nn.ModuleDict( - {d: nn.Linear(self.gw_dim, self.head_size) for d in self.domain_names} - ) + if self.per_domain_keys: + self.per_key_layers = nn.ModuleDict( + {d: nn.Linear(self.gw_dim, self.head_size) for d in self.domain_names} + ) + self.shared_key_layer = None + else: + self.shared_key_layer = nn.Linear(self.gw_dim, self.head_size) + self.per_key_layers = None @staticmethod def _calc_attention( From c192aac03377b81ac79515f186d04c5172748ae7 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Wed, 10 Dec 2025 16:48:44 +0000 Subject: [PATCH 19/33] Annotate LearnedAttention key layers as optional for mypy --- shimmer/modules/selection.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index ca2e15d..9e7f63d 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -185,6 +185,8 @@ def __init__( # Projections self.query_layer = nn.Linear(self.gw_dim, self.head_size) + self.per_key_layers: nn.ModuleDict[str, nn.Linear] | None + self.shared_key_layer: nn.Linear | None if self.per_domain_keys: self.per_key_layers = nn.ModuleDict( {d: nn.Linear(self.gw_dim, self.head_size) for d in self.domain_names} @@ -252,8 +254,16 @@ def forward( gw_latents = {d: gw_latents[d] for d in present} if self.per_domain_keys: + if self.per_key_layers is None: + raise RuntimeError( + "per_domain_keys=True but per-domain key layers are missing." + ) keys = {d: self.per_key_layers[d](gw_latents[d]) for d in present} else: + if self.shared_key_layer is None: + raise RuntimeError( + "per_domain_keys=False but shared key layer is missing." + ) proj = self.shared_key_layer keys = {d: proj(gw_latents[d]) for d in present} From bf3f75c038d110b8ab0ea54d20963d983aeb695f Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Wed, 10 Dec 2025 16:53:37 +0000 Subject: [PATCH 20/33] Fix mypy issues in LearnedAttention and ckpt migration CLI --- shimmer/cli/ckpt_migration.py | 6 +++--- shimmer/modules/selection.py | 8 ++++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/shimmer/cli/ckpt_migration.py b/shimmer/cli/ckpt_migration.py index bbc3303..c573be6 100644 --- a/shimmer/cli/ckpt_migration.py +++ b/shimmer/cli/ckpt_migration.py @@ -10,9 +10,9 @@ @click.argument( "paths", nargs=-1, - type=click.Path(exists=True, path_type=Path, file_okay=True, dir_okay=False), + type=click.Path(exists=True, file_okay=True, dir_okay=False), ) -def migrate_ckpt(paths: Sequence[Path]): +def migrate_ckpt(paths: Sequence[str]): """ Script to migrate a list of checkpoints. This can be called with: @@ -24,4 +24,4 @@ def migrate_ckpt(paths: Sequence[Path]): Internally, this calls `shimmer.utils.migrate_model` for each of the given paths. """ for path in paths: - migrate_model(path) + migrate_model(Path(path)) diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index 9e7f63d..c392143 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping +from typing import cast import torch import torch.nn as nn @@ -185,7 +186,7 @@ def __init__( # Projections self.query_layer = nn.Linear(self.gw_dim, self.head_size) - self.per_key_layers: nn.ModuleDict[str, nn.Linear] | None + self.per_key_layers: nn.ModuleDict | None self.shared_key_layer: nn.Linear | None if self.per_domain_keys: self.per_key_layers = nn.ModuleDict( @@ -258,7 +259,10 @@ def forward( raise RuntimeError( "per_domain_keys=True but per-domain key layers are missing." ) - keys = {d: self.per_key_layers[d](gw_latents[d]) for d in present} + keys = { + d: cast(nn.Linear, self.per_key_layers[d])(gw_latents[d]) + for d in present + } else: if self.shared_key_layer is None: raise RuntimeError( From 664dc4fd5ed9770bd705bb15e697af5d629e8141 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Wed, 10 Dec 2025 16:57:11 +0000 Subject: [PATCH 21/33] Revert ckpt migration typing tweaks --- shimmer/cli/ckpt_migration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/shimmer/cli/ckpt_migration.py b/shimmer/cli/ckpt_migration.py index c573be6..bbc3303 100644 --- a/shimmer/cli/ckpt_migration.py +++ b/shimmer/cli/ckpt_migration.py @@ -10,9 +10,9 @@ @click.argument( "paths", nargs=-1, - type=click.Path(exists=True, file_okay=True, dir_okay=False), + type=click.Path(exists=True, path_type=Path, file_okay=True, dir_okay=False), ) -def migrate_ckpt(paths: Sequence[str]): +def migrate_ckpt(paths: Sequence[Path]): """ Script to migrate a list of checkpoints. This can be called with: @@ -24,4 +24,4 @@ def migrate_ckpt(paths: Sequence[str]): Internally, this calls `shimmer.utils.migrate_model` for each of the given paths. """ for path in paths: - migrate_model(Path(path)) + migrate_model(path) From 6b66406f2862ea685bca3685fcdda7922f57c772 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Wed, 10 Dec 2025 17:07:08 +0000 Subject: [PATCH 22/33] Fix mypy for ckpt migration CLI by using string paths --- shimmer/cli/ckpt_migration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/shimmer/cli/ckpt_migration.py b/shimmer/cli/ckpt_migration.py index bbc3303..c573be6 100644 --- a/shimmer/cli/ckpt_migration.py +++ b/shimmer/cli/ckpt_migration.py @@ -10,9 +10,9 @@ @click.argument( "paths", nargs=-1, - type=click.Path(exists=True, path_type=Path, file_okay=True, dir_okay=False), + type=click.Path(exists=True, file_okay=True, dir_okay=False), ) -def migrate_ckpt(paths: Sequence[Path]): +def migrate_ckpt(paths: Sequence[str]): """ Script to migrate a list of checkpoints. This can be called with: @@ -24,4 +24,4 @@ def migrate_ckpt(paths: Sequence[Path]): Internally, this calls `shimmer.utils.migrate_model` for each of the given paths. """ for path in paths: - migrate_model(path) + migrate_model(Path(path)) From 101f7f1065c4cd95be9e542326d993624c116f8f Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Thu, 11 Dec 2025 09:54:49 +0000 Subject: [PATCH 23/33] Add domain-latent key option to LearnedAttention --- shimmer/modules/selection.py | 90 +++++++++++++++++++++++++++------ tests/test_learned_attention.py | 52 +++++++++++++++++++ 2 files changed, 127 insertions(+), 15 deletions(-) diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index c392143..a5860b7 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -165,6 +165,8 @@ class LearnedAttention(SelectionBase): Toggles: - per_domain_keys: use per-domain key projections instead of a shared one - stopgrad: detach GW latents before computing keys/query + - key_on_prefusion: compute keys on pre-fusion GW latents (True) or raw domains + - domain_dims: required when key_on_prefusion=False to size per-domain key layers """ def __init__( @@ -174,6 +176,8 @@ def __init__( head_size: int = 64, per_domain_keys: bool = False, stopgrad: bool = True, + key_on_prefusion: bool = True, + domain_dims: Mapping[str, int] | None = None, ): super().__init__() self.gw_dim = int(gw_dim) @@ -183,19 +187,49 @@ def __init__( # Toggles self.per_domain_keys = bool(per_domain_keys) self.stopgrad = bool(stopgrad) + self.key_on_prefusion = bool(key_on_prefusion) + self.domain_dims = dict(domain_dims) if domain_dims is not None else None # Projections self.query_layer = nn.Linear(self.gw_dim, self.head_size) self.per_key_layers: nn.ModuleDict | None self.shared_key_layer: nn.Linear | None - if self.per_domain_keys: + if self.key_on_prefusion: + if self.per_domain_keys: + self.per_key_layers = nn.ModuleDict( + { + d: nn.Linear(self.gw_dim, self.head_size) + for d in self.domain_names + } + ) + self.shared_key_layer = None + else: + self.shared_key_layer = nn.Linear(self.gw_dim, self.head_size) + self.per_key_layers = None + else: + if not self.per_domain_keys: + raise ValueError( + "key_on_prefusion=False requires per_domain_keys=True because " + "domain latent dimensions can differ." + ) + if self.domain_dims is None: + raise ValueError( + "key_on_prefusion=False requires domain_dims for key projections." + ) + missing_dims = [ + d for d in self.domain_names if d not in self.domain_dims + ] + if missing_dims: + raise ValueError( + f"Missing domain_dims for: {', '.join(sorted(missing_dims))}" + ) self.per_key_layers = nn.ModuleDict( - {d: nn.Linear(self.gw_dim, self.head_size) for d in self.domain_names} + { + d: nn.Linear(self.domain_dims[d], self.head_size) + for d in self.domain_names + } ) self.shared_key_layer = None - else: - self.shared_key_layer = nn.Linear(self.gw_dim, self.head_size) - self.per_key_layers = None @staticmethod def _calc_attention( @@ -234,25 +268,51 @@ def forward( """ Args: domains: mapping from domain name to GW latent (B, gw_dim) - encodings_pre_fusion: unused; kept for `SelectionBase` compatibility. + encodings_pre_fusion: pre-fusion encodings (used when key_on_prefusion) Returns: dict[str, torch.Tensor]: per-domain attention weights. """ - del encodings_pre_fusion # unused - - gw_latents: Mapping[str, torch.Tensor] = domains + domain_latents: Mapping[str, torch.Tensor] = domains - present = [d for d in self.domain_names if d in gw_latents] + present = [d for d in self.domain_names if d in domain_latents] if not present: raise ValueError( "LearnedAttention: no known domains present in gw_latents." ) + if self.key_on_prefusion: + if encodings_pre_fusion is None: + raise ValueError( + "key_on_prefusion=True requires encodings_pre_fusion inputs." + ) + key_source = encodings_pre_fusion + else: + key_source = domain_latents + + missing_keys = [d for d in present if d not in key_source] + if missing_keys: + raise ValueError( + f"Missing key latents for: {', '.join(sorted(missing_keys))}" + ) + + if encodings_pre_fusion is None: + query_source = domain_latents + else: + query_source = encodings_pre_fusion + + missing_query = [d for d in present if d not in query_source] + if missing_query: + raise ValueError( + f"Missing query latents for: {', '.join(sorted(missing_query))}" + ) + if self.stopgrad: - gw_latents = {d: t.detach() for d, t in gw_latents.items() if d in present} + key_latents = {d: key_source[d].detach() for d in present} + query_latents = {d: query_source[d].detach() for d in present} else: - gw_latents = {d: gw_latents[d] for d in present} + key_latents = {d: key_source[d] for d in present} + query_latents = {d: query_source[d] for d in present} if self.per_domain_keys: if self.per_key_layers is None: @@ -260,7 +320,7 @@ def forward( "per_domain_keys=True but per-domain key layers are missing." ) keys = { - d: cast(nn.Linear, self.per_key_layers[d])(gw_latents[d]) + d: cast(nn.Linear, self.per_key_layers[d])(key_latents[d]) for d in present } else: @@ -269,9 +329,9 @@ def forward( "per_domain_keys=False but shared key layer is missing." ) proj = self.shared_key_layer - keys = {d: proj(gw_latents[d]) for d in present} + keys = {d: proj(key_latents[d]) for d in present} - stacked = torch.stack([gw_latents[d] for d in present], dim=0) # (D, B, F) + stacked = torch.stack([query_latents[d] for d in present], dim=0) # (D, B, F) query = self.query_layer(stacked.mean(0)) # (B, H) return self._calc_attention( diff --git a/tests/test_learned_attention.py b/tests/test_learned_attention.py index 3144915..4fcec9a 100644 --- a/tests/test_learned_attention.py +++ b/tests/test_learned_attention.py @@ -1,3 +1,4 @@ +import pytest import torch from shimmer.modules.selection import LearnedAttention @@ -51,3 +52,54 @@ def test_learned_attention_stopgrad_toggle() -> None: torch.stack(list(trainable_weights.values())).sum().backward() assert train_latents["a"].grad is not None assert train_latents["b"].grad is not None + + +def test_learned_attention_domain_key_path() -> None: + domain_dims = {"a": 3, "b": 5} + selector = LearnedAttention( + gw_dim=4, + domain_names=domain_dims.keys(), + head_size=3, + per_domain_keys=True, + stopgrad=False, + key_on_prefusion=False, + domain_dims=domain_dims, + ) + + domain_latents = { + "a": torch.randn(6, 3, requires_grad=True), + "b": torch.randn(6, 5, requires_grad=True), + } + prefusion_latents = { + "a": torch.randn(6, 4, requires_grad=True), + "b": torch.randn(6, 4, requires_grad=True), + } + + weights = selector(domain_latents, encodings_pre_fusion=prefusion_latents) + + stacked = torch.stack([weights["a"], weights["b"]], dim=1) + assert torch.allclose(stacked.sum(dim=1), torch.ones(6)) + + +def test_learned_attention_domain_key_shared_layer_error() -> None: + domain_dims = {"a": 3, "b": 5} + with pytest.raises(ValueError): + LearnedAttention( + gw_dim=4, + domain_names=domain_dims.keys(), + head_size=3, + per_domain_keys=False, + stopgrad=True, + key_on_prefusion=False, + domain_dims=domain_dims, + ) + + with pytest.raises(ValueError): + LearnedAttention( + gw_dim=4, + domain_names=domain_dims.keys(), + head_size=3, + per_domain_keys=True, + stopgrad=True, + key_on_prefusion=False, + ) From d1badaeb24800658ec4970476788520e8d85a1a7 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Thu, 11 Dec 2025 09:58:18 +0000 Subject: [PATCH 24/33] Format LearnedAttention per ruff --- shimmer/modules/selection.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index a5860b7..062649b 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -216,9 +216,7 @@ def __init__( raise ValueError( "key_on_prefusion=False requires domain_dims for key projections." ) - missing_dims = [ - d for d in self.domain_names if d not in self.domain_dims - ] + missing_dims = [d for d in self.domain_names if d not in self.domain_dims] if missing_dims: raise ValueError( f"Missing domain_dims for: {', '.join(sorted(missing_dims))}" From 02167d165f2829206ae05e58b098578253d48206 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Thu, 11 Dec 2025 10:13:06 +0000 Subject: [PATCH 25/33] Pass domain key options through init_learned_attention --- shimmer/modules/global_workspace.py | 22 ++++++++++ tests/test_learned_attention.py | 67 ++++++++++++++++++++++++++++- 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 64a18ee..ffaffc7 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -797,6 +797,8 @@ def init_learned_attention( head_size: int = 64, per_domain_keys: bool = False, stopgrad: bool = True, + key_on_prefusion: bool = True, + domain_dims: Mapping[str, int] | None = None, ) -> LearnedAttention: """ Initialize and attach a learned content-based attention module. @@ -805,12 +807,32 @@ def init_learned_attention( the current workspace (uses `workspace_dim` and domain names from `domain_mods`), ensuring its parameters are tracked by Lightning/torch. """ + if not key_on_prefusion and not per_domain_keys: + raise ValueError( + "key_on_prefusion=False requires per_domain_keys=True because " + "domain latent dimensions can differ." + ) + + final_domain_dims = domain_dims + if not key_on_prefusion: + if final_domain_dims is None: + final_domain_dims = { + name: mod.latent_dim for name, mod in self.domain_mods.items() + } + missing = [d for d in self.domain_mods if d not in final_domain_dims] + if missing: + raise ValueError( + f"Missing domain_dims for: {', '.join(sorted(missing))}" + ) + selection = LearnedAttention( gw_dim=self.workspace_dim, domain_names=self.domain_mods.keys(), head_size=head_size, per_domain_keys=per_domain_keys, stopgrad=stopgrad, + key_on_prefusion=key_on_prefusion, + domain_dims=final_domain_dims, ) self.selection_mod = selection return selection diff --git a/tests/test_learned_attention.py b/tests/test_learned_attention.py index 4fcec9a..d71796f 100644 --- a/tests/test_learned_attention.py +++ b/tests/test_learned_attention.py @@ -1,6 +1,12 @@ import pytest import torch +import torch.nn as nn +from shimmer.modules.domain import DomainModule +from shimmer.modules.global_workspace import ( + GlobalWorkspaceFusion, + freeze_domain_modules, +) from shimmer.modules.selection import LearnedAttention @@ -94,10 +100,69 @@ def test_learned_attention_domain_key_shared_layer_error() -> None: domain_dims=domain_dims, ) + +class _DummyDomain(DomainModule): + def __init__(self, latent_dim: int): + super().__init__(latent_dim) + + def encode(self, x: torch.Tensor) -> torch.Tensor: # pragma: no cover - simple stub + return x + + def decode(self, z: torch.Tensor) -> torch.Tensor: # pragma: no cover - simple stub + return z + + +def test_global_workspace_init_learned_attention_domain_dims() -> None: + domain_mods = freeze_domain_modules({"a": _DummyDomain(3), "b": _DummyDomain(5)}) + gw_encoders = {"a": nn.Identity(), "b": nn.Identity()} + gw_decoders = {"a": nn.Identity(), "b": nn.Identity()} + + gw = GlobalWorkspaceFusion( + domain_mods=domain_mods, + gw_encoders=gw_encoders, + gw_decoders=gw_decoders, + workspace_dim=4, + loss_coefs={"contrastives": 0.0}, + ) + + selector = gw.init_learned_attention( + head_size=2, + per_domain_keys=True, + stopgrad=False, + key_on_prefusion=False, + ) + + assert selector.key_on_prefusion is False + assert selector.per_key_layers is not None + assert selector.per_key_layers["a"].weight.shape[1] == 3 + assert selector.per_key_layers["b"].weight.shape[1] == 5 + + +def test_global_workspace_init_learned_attention_shared_error() -> None: + domain_mods = freeze_domain_modules({"a": _DummyDomain(3), "b": _DummyDomain(5)}) + gw_encoders = {"a": nn.Identity(), "b": nn.Identity()} + gw_decoders = {"a": nn.Identity(), "b": nn.Identity()} + + gw = GlobalWorkspaceFusion( + domain_mods=domain_mods, + gw_encoders=gw_encoders, + gw_decoders=gw_decoders, + workspace_dim=4, + loss_coefs={"contrastives": 0.0}, + ) + + with pytest.raises(ValueError): + gw.init_learned_attention( + head_size=2, + per_domain_keys=False, + stopgrad=True, + key_on_prefusion=False, + ) + with pytest.raises(ValueError): LearnedAttention( gw_dim=4, - domain_names=domain_dims.keys(), + domain_names=["a", "b"], head_size=3, per_domain_keys=True, stopgrad=True, From 8388448dc3973ac58a5faa3d7617cf04b7b170d5 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Thu, 11 Dec 2025 10:15:26 +0000 Subject: [PATCH 26/33] Drop duplicate missing-domain-dims check in attention tests --- tests/test_learned_attention.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/test_learned_attention.py b/tests/test_learned_attention.py index d71796f..b361c70 100644 --- a/tests/test_learned_attention.py +++ b/tests/test_learned_attention.py @@ -158,13 +158,3 @@ def test_global_workspace_init_learned_attention_shared_error() -> None: stopgrad=True, key_on_prefusion=False, ) - - with pytest.raises(ValueError): - LearnedAttention( - gw_dim=4, - domain_names=["a", "b"], - head_size=3, - per_domain_keys=True, - stopgrad=True, - key_on_prefusion=False, - ) From 553ab6f139170c7c04007683449d04b27234bcf1 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Thu, 11 Dec 2025 10:35:07 +0000 Subject: [PATCH 27/33] Warn before initializing LearnedAttention --- shimmer/modules/global_workspace.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index ffaffc7..07f8488 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -1,3 +1,4 @@ +import warnings from collections.abc import Callable, Iterable, Mapping from enum import Enum, auto from pathlib import Path @@ -807,6 +808,15 @@ def init_learned_attention( the current workspace (uses `workspace_dim` and domain names from `domain_mods`), ensuring its parameters are tracked by Lightning/torch. """ + warnings.warn( + ( + "LearnedAttention is best used after pretraining the global workspace " + "with a simpler selection (e.g., random or single-domain). " + "This path is minimally validated; use at your own risk." + ), + UserWarning, + stacklevel=2, + ) if not key_on_prefusion and not per_domain_keys: raise ValueError( "key_on_prefusion=False requires per_domain_keys=True because " From 887cbcfa45dd427746149e24292c5ceca0587529 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Wed, 17 Dec 2025 11:08:13 +0000 Subject: [PATCH 28/33] Refactor broadcast naming and stop logging metrics --- shimmer/modules/losses.py | 55 +++++++++++++++++---------------------- tests/test_broadcast.py | 20 +++++++------- 2 files changed, 33 insertions(+), 42 deletions(-) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index b01cf65..45a8a48 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -367,7 +367,7 @@ class CycleCase(TypedDict): raw_group: Mapping[str, object] -class BroadcastLossResult(TypedDict): +class BroadcastResult(TypedDict): """ Broadcast loss output without cycle computation. @@ -535,13 +535,13 @@ def generate_partitions(n: int) -> Generator[tuple[int, ...], None, None]: yield perm -def broadcast_loss( +def broadcast( gw_mod: GWModuleBase, selection_mod: SelectionBase, domain_mods: Mapping[str, DomainModule], latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT, -) -> BroadcastLossResult: +) -> BroadcastResult: """ Computes broadcast demi-cycle (with fused) and translation losses, and prepares precomputed artifacts for cycle losses. @@ -574,7 +574,7 @@ def broadcast_loss( raw_data (`RawDomainGroupsT`): raw input data Returns: - `BroadcastLossResult`: demi/translation metrics plus precomputed cycle data. + `BroadcastResult`: demi/translation metrics plus precomputed cycle data. """ # noqa: E501 losses: dict[str, torch.Tensor] = {} metrics: dict[str, torch.Tensor] = {} @@ -658,7 +658,7 @@ def broadcast_loss( ) metrics.update(losses) - return BroadcastLossResult(metrics=metrics, cycle_cases=cycle_cases) + return BroadcastResult(metrics=metrics, cycle_cases=cycle_cases) def cycle_loss_from_broadcast( @@ -674,7 +674,7 @@ def cycle_loss_from_broadcast( gw_mod: GW module used for encoding/decoding. selection_mod: selection module used during fusion. domain_mods: domain modules used to compute the losses. - cycle_cases: precomputed cycle data produced by `broadcast_loss`. + cycle_cases: precomputed cycle data produced by `broadcast`. Returns: Metrics dict containing per-case losses/metrics and aggregate `cycles`. @@ -772,10 +772,10 @@ def contrastive_loss( return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn) - def broadcast_loss( + def broadcast( self, latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT - ) -> BroadcastLossResult: - return broadcast_loss( + ) -> BroadcastResult: + return broadcast( self.gw_mod, self.selection_mod, self.domain_mods, latent_domains, raw_data ) @@ -797,29 +797,22 @@ def step( A LossOutput object containing the loss and metrics for this step. """ - metrics: dict[str, torch.Tensor] = {} - - metrics.update(self.contrastive_loss(domain_latents)) - broadcast_result = self.broadcast_loss(domain_latents, raw_data) - metrics.update(broadcast_result["metrics"]) - metrics.update( - cycle_loss_from_broadcast( - self.gw_mod, - self.selection_mod, - self.domain_mods, - broadcast_result["cycle_cases"], - ) + contrastive_metrics = self.contrastive_loss(domain_latents) + broadcast_result = self.broadcast(domain_latents, raw_data) + cycle_metrics = cycle_loss_from_broadcast( + self.gw_mod, + self.selection_mod, + self.domain_mods, + broadcast_result["cycle_cases"], ) - loss = combine_loss(metrics, self.loss_coefs) + loss_inputs: dict[str, torch.Tensor] = { + **contrastive_metrics, + **broadcast_result["metrics"], + **cycle_metrics, + } - metrics["broadcast_loss"] = torch.stack( - [ - metrics[name] - for name, coef in self.loss_coefs.items() - if isinstance(coef, float) and name != "contrastives" - ], - dim=0, - ).mean() + loss = combine_loss(loss_inputs, self.loss_coefs) - return LossOutput(loss, metrics) + # Do not log broadcast components; keep non-broadcast metrics only. + return LossOutput(loss, metrics=dict(contrastive_metrics)) diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index ddbba22..e2a6d20 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -27,7 +27,7 @@ def compute_loss( return LossOutput(loss=loss) # Constructing LossOutput with the loss -def test_broadcast_loss(): +def test_broadcast(): domain_mods: dict[str, DomainModule] = { "domain1": DummyDomainModule(latent_dim=10), "domain2": DummyDomainModule(latent_dim=10), @@ -55,7 +55,7 @@ def test_broadcast_loss(): learn_logit_scale=False, ) - # Adjusting the dummy data to fit the expected input structure for broadcast_loss + # Adjusting the dummy data to fit the expected input structure for broadcast # Now using a frozenset for the keys to match LatentsDomainGroupsT latent_domains = { frozenset(["domain1", "domain2"]): { @@ -64,20 +64,18 @@ def test_broadcast_loss(): } } - # Test broadcast_loss with the corrected structure - result = gw_fusion.loss_mod.broadcast_loss(latent_domains, latent_domains) + # Test broadcast with the corrected structure + result = gw_fusion.loss_mod.broadcast(latent_domains, latent_domains) assert "metrics" in result and "cycle_cases" in result - # Cycle metrics are computed in step(), not within broadcast_loss + # Cycle metrics are computed in step(), not within broadcast metrics = result["metrics"] assert all(metric in metrics for metric in ["demi_cycles", "translations"]) - # Cycle metrics should appear after the step call + # Broadcast metrics should not be logged from step() step_output = gw_fusion.loss_mod.step(latent_domains, latent_domains, mode="train") - er_msg = "Demi-cycle, cycle and translation metrics should be in the output." - assert all( - metric in step_output.metrics - for metric in ["demi_cycles", "cycles", "translations"] - ) + assert "demi_cycles" not in step_output.metrics + assert "translations" not in step_output.metrics + assert "cycles" not in step_output.metrics er_msg = "Losses should be scalar tensors or 1D tensor with size equal to one." assert all( From 17b467ae71194752b4cdb7ab72a9236bad030a2b Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Wed, 17 Dec 2025 11:59:12 +0000 Subject: [PATCH 29/33] Warn when loss coef missing --- shimmer/modules/losses.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 45a8a48..9c15c09 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +import warnings from collections.abc import Generator, Mapping from itertools import product from typing import TypedDict @@ -287,9 +288,9 @@ class LossCoefs(TypedDict, total=False): """ Dict of loss coefficients used in the GWLosses. - If one is not provided, the coefficient is assumed to be 0 and will not be logged. - If the loss is excplicitely set to 0, it will be logged, but not take part in - the total loss. + If one is not provided, the coefficient is assumed to be 0 and will not be logged + (a warning is emitted). If the loss is explicitly set to 0, it will be logged, but + not take part in the total loss. """ demi_cycles: float @@ -309,9 +310,9 @@ class BroadcastLossCoefs(TypedDict, total=False): """ Dict of loss coefficients used in the GWLossesFusion. - If one is not provided, the coefficient is assumed to be 0 and will not be logged. - If the loss is excplicitely set to 0, it will be logged, but not take part in - the total loss. + If one is not provided, the coefficient is assumed to be 0 and will not be logged + (a warning is emitted). If the loss is explicitly set to 0, it will be logged, but + not take part in the total loss. """ contrastives: float @@ -346,6 +347,20 @@ def combine_loss( Returns: `torch.Tensor`: the combined loss. """ + missing = { + name + for name in _EXPECTED_COEF_KEYS + if name in metrics and name not in coefs + } + for name in sorted(missing): + if name not in _MISSING_COEFS_WARNED: + warnings.warn( + f"Loss coefficient '{name}' not provided; defaulting to 0.", + UserWarning, + stacklevel=2, + ) + _MISSING_COEFS_WARNED.add(name) + loss = torch.stack( [ metrics[name] * coef @@ -357,6 +372,10 @@ def combine_loss( return loss +_EXPECTED_COEF_KEYS = {"contrastives", "demi_cycles", "cycles", "translations"} +_MISSING_COEFS_WARNED: set[str] = set() + + class CycleCase(TypedDict): """Container for precomputed cycle inputs to avoid recomputation.""" From c6506ef001ca76a6b4b6041f69e8abe9d88f04ec Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Wed, 17 Dec 2025 11:59:24 +0000 Subject: [PATCH 30/33] Add warning test for missing loss coef --- tests/test_loss_coefs_warning.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/test_loss_coefs_warning.py diff --git a/tests/test_loss_coefs_warning.py b/tests/test_loss_coefs_warning.py new file mode 100644 index 0000000..01928c2 --- /dev/null +++ b/tests/test_loss_coefs_warning.py @@ -0,0 +1,20 @@ +import warnings + +import torch + +from shimmer.modules.losses import combine_loss + + +def test_missing_loss_coef_warns_and_defaults_to_zero() -> None: + metrics = { + "demi_cycles": torch.tensor(1.0), + "contrastives": torch.tensor(2.0), + } + coefs = {"contrastives": 1.0} + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + loss = combine_loss(metrics, coefs) + + assert any("demi_cycles" in str(w.message) for w in caught) + assert torch.isclose(loss, torch.tensor(2.0)) From 4dc4e7ccf63df8fc70010b8bbfdecda8ae1632cd Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Wed, 17 Dec 2025 12:09:59 +0000 Subject: [PATCH 31/33] Clarify broadcast docstring --- shimmer/modules/losses.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 9c15c09..4e0a4cf 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -565,16 +565,15 @@ def broadcast( Computes broadcast demi-cycle (with fused) and translation losses, and prepares precomputed artifacts for cycle losses. - This return multiple metrics: + This returns multiple metrics: * `demi_cycles` - * `cycles` * `translations` * `from_{start_group}_to_{domain}_loss` where `{start_group}` is of the form "{domain1,domain2,domainN}" sorted in alphabetical order (e.g. "from_{t,v}_to_t_loss"). Note: fused cases are aggregated into `demi_cycles`. * `from_{start_group}_to_{domain}_{metric}` with additional metrics provided by - the domain_mod's `compute_broadcast_loss` output + the domain module's loss outputs * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_loss` where `{start_group}`, `{target_group}` and `{case_group}` is of the form "{domain1,domain2,domainN}" sorted in alphabetical order @@ -582,8 +581,7 @@ def broadcast( domains, `{target_group}` the target domains used for the cycle and `{case_group}` all available domains participating to the loss. * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_{metric}` - additional metrics provided by the domain_mod's `compute_broadcast_loss` - output + additional metrics provided by the domain module's loss outputs Args: gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use From 8601c928c318a6ef93cdb407feea8effc57db2c7 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Wed, 17 Dec 2025 14:13:18 +0000 Subject: [PATCH 32/33] Run ruff format --- shimmer/modules/losses.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 4e0a4cf..3dd5859 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -1,5 +1,5 @@ -from abc import ABC, abstractmethod import warnings +from abc import ABC, abstractmethod from collections.abc import Generator, Mapping from itertools import product from typing import TypedDict @@ -348,9 +348,7 @@ def combine_loss( `torch.Tensor`: the combined loss. """ missing = { - name - for name in _EXPECTED_COEF_KEYS - if name in metrics and name not in coefs + name for name in _EXPECTED_COEF_KEYS if name in metrics and name not in coefs } for name in sorted(missing): if name not in _MISSING_COEFS_WARNED: From 3213179f9bd531d72fc25f2b8d5c62b417359068 Mon Sep 17 00:00:00 2001 From: RolandBERTINJOHANNET Date: Fri, 19 Dec 2025 17:57:40 +0000 Subject: [PATCH 33/33] Restore broadcast metrics logging (minus aggregate) --- shimmer/modules/losses.py | 29 +++++++++++++++-------------- tests/test_broadcast.py | 8 ++++---- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 3dd5859..202040b 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -812,22 +812,23 @@ def step( A LossOutput object containing the loss and metrics for this step. """ - contrastive_metrics = self.contrastive_loss(domain_latents) + metrics: dict[str, torch.Tensor] = {} + + metrics.update(self.contrastive_loss(domain_latents)) broadcast_result = self.broadcast(domain_latents, raw_data) - cycle_metrics = cycle_loss_from_broadcast( - self.gw_mod, - self.selection_mod, - self.domain_mods, - broadcast_result["cycle_cases"], + metrics.update(broadcast_result["metrics"]) + metrics.update( + cycle_loss_from_broadcast( + self.gw_mod, + self.selection_mod, + self.domain_mods, + broadcast_result["cycle_cases"], + ) ) - loss_inputs: dict[str, torch.Tensor] = { - **contrastive_metrics, - **broadcast_result["metrics"], - **cycle_metrics, - } + loss = combine_loss(metrics, self.loss_coefs) - loss = combine_loss(loss_inputs, self.loss_coefs) + # Do not expose the deprecated broadcast_loss aggregate. + metrics.pop("broadcast_loss", None) - # Do not log broadcast components; keep non-broadcast metrics only. - return LossOutput(loss, metrics=dict(contrastive_metrics)) + return LossOutput(loss, metrics) diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index e2a6d20..7900e86 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -71,11 +71,11 @@ def test_broadcast(): metrics = result["metrics"] assert all(metric in metrics for metric in ["demi_cycles", "translations"]) - # Broadcast metrics should not be logged from step() + # Broadcast metrics should be logged from step (but not the deprecated aggregate) step_output = gw_fusion.loss_mod.step(latent_domains, latent_domains, mode="train") - assert "demi_cycles" not in step_output.metrics - assert "translations" not in step_output.metrics - assert "cycles" not in step_output.metrics + assert "broadcast_loss" not in step_output.metrics + for metric in ["demi_cycles", "translations", "cycles"]: + assert metric in step_output.metrics er_msg = "Losses should be scalar tensors or 1D tensor with size equal to one." assert all(