diff --git a/shimmer/__init__.py b/shimmer/__init__.py index 5e8355d..763bcfe 100644 --- a/shimmer/__init__.py +++ b/shimmer/__init__.py @@ -33,7 +33,6 @@ translation, ) from shimmer.modules.losses import ( - BroadcastLossCoefs, GWLosses2Domains, GWLossesBase, LossCoefs, @@ -85,7 +84,6 @@ "contrastive_loss", "ContrastiveLoss", "LossCoefs", - "BroadcastLossCoefs", "combine_loss", "GWLossesBase", "GWLosses2Domains", diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py index cd5957e..9145a76 100644 --- a/shimmer/modules/__init__.py +++ b/shimmer/modules/__init__.py @@ -27,7 +27,6 @@ translation, ) from shimmer.modules.losses import ( - BroadcastLossCoefs, GWLosses2Domains, GWLossesBase, LossCoefs, @@ -63,7 +62,6 @@ "contrastive_loss", "ContrastiveLoss", "LossCoefs", - "BroadcastLossCoefs", "combine_loss", "GWLossesBase", "GWLosses2Domains", diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 07f8488..4d0c251 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -23,7 +23,6 @@ translation, ) from shimmer.modules.losses import ( - BroadcastLossCoefs, GWLosses, GWLosses2Domains, GWLossesBase, @@ -721,7 +720,7 @@ def __init__( gw_encoders: Mapping[str, Module], gw_decoders: Mapping[str, Module], workspace_dim: int, - loss_coefs: BroadcastLossCoefs | Mapping[str, float], + loss_coefs: LossCoefs | Mapping[str, float], selection_temperature: float = 0.2, selection_mod: SelectionBase | None = None, optim_lr: float = 1e-3, diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 202040b..e6ed67c 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -306,31 +306,9 @@ class LossCoefs(TypedDict, total=False): """Contrastive loss coefficient.""" -class BroadcastLossCoefs(TypedDict, total=False): - """ - Dict of loss coefficients used in the GWLossesFusion. - - If one is not provided, the coefficient is assumed to be 0 and will not be logged - (a warning is emitted). If the loss is explicitly set to 0, it will be logged, but - not take part in the total loss. - """ - - contrastives: float - """Contrastive loss coefficient.""" - - demi_cycles: float - """demi_cycles loss coefficient. Demi-cycles aggregate fused cases too.""" - - cycles: float - """cycles loss coefficient. Cycles can be many-to-one""" - - translations: float - """translation loss coefficient. Translation, like cycles, can be many-to-one.""" - - def combine_loss( metrics: dict[str, torch.Tensor], - coefs: Mapping[str, float] | LossCoefs | BroadcastLossCoefs, + coefs: Mapping[str, float] | LossCoefs, ) -> torch.Tensor: """ Combines the metrics according to the ones selected in coefs @@ -626,12 +604,10 @@ def broadcast( continue ground_truth = latents[domain] - if num_active_domains == 1 and domain in selected_latents: + if domain in selected_latents: loss_fn = domain_mods[domain].compute_dcy_loss - elif domain not in selected_latents: - loss_fn = domain_mods[domain].compute_tr_loss else: - loss_fn = domain_mods[domain].compute_dcy_loss + loss_fn = domain_mods[domain].compute_tr_loss loss_output = loss_fn( pred, ground_truth, raw_data[group_domains][domain] @@ -645,12 +621,10 @@ def broadcast( {f"{loss_label}_{k}": v for k, v in loss_output.metrics.items()} ) - if num_active_domains == 1 and domain in selected_latents: + if domain in selected_latents: demi_cycle_losses.append(loss_label + "_loss") - elif domain not in selected_latents: + else: translation_losses.append(loss_label + "_loss") - else: # fused loss counts toward demi_cycles aggregate - demi_cycle_losses.append(loss_label + "_loss") if num_active_domains < num_total_domains: cycle_cases.append( @@ -752,7 +726,7 @@ def __init__( gw_mod: GWModule, selection_mod: SelectionBase, domain_mods: dict[str, DomainModule], - loss_coefs: BroadcastLossCoefs | Mapping[str, float], + loss_coefs: LossCoefs | Mapping[str, float], contrastive_fn: ContrastiveLossType, ): """ @@ -762,7 +736,7 @@ def __init__( gw_mod: The GWModule for the global workspace. selection_mod: The selection mechanism for the model. domain_mods: A mapping of domain names to their respective DomainModule. - loss_coefs (`BroadcastLossCoefs`): coefs for the losses + loss_coefs (`LossCoefs`): coefs for the losses contrastive_fn: The function used for computing contrastive loss. """ super().__init__() diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index 7900e86..ff5e39e 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -5,7 +5,7 @@ from shimmer.modules.domain import DomainModule, LossOutput from shimmer.modules.global_workspace import GlobalWorkspaceFusion -from shimmer.modules.losses import BroadcastLossCoefs +from shimmer.modules.losses import LossCoefs class DummyDomainModule(DomainModule): @@ -35,7 +35,7 @@ def test_broadcast(): gw_encoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)} gw_decoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)} workspace_dim = 10 - loss_coefs: BroadcastLossCoefs = { + loss_coefs: LossCoefs = { "cycles": 1.0, "demi_cycles": 1.0, "translations": 1.0,