2 changes: 1 addition & 1 deletion Makefile
@@ -28,4 +28,4 @@ quality:
$(TOOL) ty check $(check_dirs)

test:
WANDB_DISABLED=true $(RUN) pytest -vv --import-mode=importlib
WANDB_MODE=offline $(RUN) pytest -vv --import-mode=importlib
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -33,8 +33,8 @@ npu = [
"decorator",
"psutil",
"importlib_metadata",
"lightning @ git+https://github.com/SmallPigPeppa/pytorch-lightning.git@npu_support_fp16_mixed",
"lycil[wandb]"
"lycil[wandb]",
# "lightning @ git+https://github.com/SmallPigPeppa/pytorch-lightning.git@npu_support_fp16_mixed",
]
ghactions = [
"torch @ https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'",
@@ -53,7 +53,7 @@ lightning = [
"lightning>=2.2.0,<2.6.0",
]
wandb = [
"wandb>=0.12.10,<0.23.0"
"wandb>=0.20.0,<0.25.0,!=0.24.0"
]

[project.urls]
4 changes: 3 additions & 1 deletion src/lycil/classifier/__init__.py
@@ -15,7 +15,7 @@
_CLASSIFIER_HEADS: dict[str, tuple[type[nn.Module], dict]] = {
# key: (class, {optional kwargs})
"linear": (SimpleLinear, {}),
"cosine": (CosineLinear, {"learn_scale": True}),
"cosine": (CosineLinear, {"num_proxy": 10, "to_reduce": True, "learn_scale": True}),
}


@@ -70,11 +70,13 @@ def expand_head(module: nn.Module, num_new: int) -> nn.Module:
if isinstance(module, CosineLinear):
new_linear = SplitCosineLinear.from_cosine_linear(module, num_new)
new_linear.old_head.requires_grad_(False)
new_linear.new_head.requires_grad_(True)
return new_linear

if isinstance(module, SplitCosineLinear):
new_linear = SplitCosineLinear.from_split_cosine_linear(module, num_new)
new_linear.old_head.requires_grad_(False)
new_linear.new_head.requires_grad_(True)
return new_linear

raise NotImplementedError(f"Classifier not expandable: {type(module)}.")
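
For orientation, a minimal sketch of how a registry entry might be turned into a head instance and then grown. The `build_head` factory and the `(in_features, out_features, **kwargs)` constructor signature are assumptions for illustration, not part of this diff.

```python
from torch import nn

from lycil.classifier import _CLASSIFIER_HEADS, expand_head


def build_head(name: str, in_features: int, out_features: int) -> nn.Module:
    # Hypothetical factory: "cosine" now resolves to CosineLinear with
    # num_proxy=10, to_reduce=True, learn_scale=True as default kwargs.
    head_cls, default_kwargs = _CLASSIFIER_HEADS[name]
    return head_cls(in_features, out_features, **default_kwargs)


head = build_head("cosine", 512, 10)
# Expanding by 5 classes returns a SplitCosineLinear whose old_head is frozen
# and whose new_head stays trainable (per the change above).
head = expand_head(head, num_new=5)
```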
3 changes: 1 addition & 2 deletions src/lycil/classifier/linears.py
@@ -1,11 +1,10 @@
import math
from typing import TypedDict, Callable
from typing import Callable, TypedDict

import torch
from torch import nn
from torch.nn import functional as F


LinearHead = Callable[[torch.Tensor], dict[str, torch.Tensor]]


6 changes: 4 additions & 2 deletions src/lycil/data/hfmodule.py
@@ -68,6 +68,8 @@ def __init__(
self.buffer: BaseExemplarBuffer | None = (
BaseExemplarBuffer(**buffer_kwargs) if buffer_kwargs is not None else None
)
self.train_filter_fn: Callable[[dict], bool] | None = None
self.use_buffer: bool = True

self._cur_task_id: int = 0
self.dataset: DatasetDict
@@ -227,10 +229,10 @@ def get_dataloader(
def train_dataloader(self):
return self.get_dataloader(
split=self._split_train,
filter_fn=self.is_label_in_cur_task,
filter_fn=self.train_filter_fn or self.is_label_in_cur_task,
transform_name=self.get_effective_transform_name("train"),
loader_kwargs=self.train_loader_kwargs,
use_buffer=True,
use_buffer=self.use_buffer,
)

def val_dataloader(self):
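
A hedged sketch of how the new `train_filter_fn` / `use_buffer` hooks could drive a buffer-only replay phase; how `dm` is built and fitted is assumed, not shown in this diff.

```python
# Assumed usage sketch: `dm` is the HFDataModule built elsewhere in the project.
dm.use_buffer = True                   # keep exemplar-buffer samples in the train loader
dm.train_filter_fn = lambda ex: False  # drop current-task rows, i.e. replay the buffer only
# ... run a fit on the buffer here ...
dm.train_filter_fn = None              # back to the default `is_label_in_cur_task` filter
```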
67 changes: 44 additions & 23 deletions src/lycil/learner/base.py
@@ -74,7 +74,18 @@ def sync_with_datamodule(self, dm: "HFDataModule"):
Args:
dm (HFDataModule): Data module to sync with.
"""
self.task_id = dm.get_current_task()
dm_task_id = dm.get_current_task()
if self.task_id is not None and dm_task_id == self.task_id:
# in sync, no update
return
if self.task_id is not None and self.task_id < 0:
# special bypass for buffer-only training: a negative task_id skips this
# sync entirely, so the head is not expanded. Activate it with
# `learner.task_id = -2` and restore it with `learner.task_id = cur_task_id`.
return

self.task_id = dm_task_id

incoming_expansion = dm.num_seen_classes - (self.num_seen_classes or 0)
if incoming_expansion <= 0:
@@ -85,6 +96,8 @@ def sync_with_datamodule(self, dm: "HFDataModule"):
+ "Ensure that `sync_with_datamodule()` is called after datamodule updates."
)

self.expand_head(incoming_expansion)

self.num_old_classes = self.num_seen_classes or 0
self.num_seen_classes = dm.num_seen_classes

@@ -131,41 +144,50 @@ def forward_layerwise(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
@abstractmethod
def update_memory(self, *args, **kwargs): ...

def configure_optimizers(self):
params = [p for p in self.parameters() if p.requires_grad]

optim_kwargs = (
self.per_task_optim_args.get(self.task_id)
or self.per_task_optim_args.get(-1)
or {}
)
opt_type = optim_kwargs.pop("type", "sgd")
@staticmethod
def _get_optimizer(*args, **kwargs):
opt_type = kwargs.pop("type", "sgd")
match opt_type:
case "sgd":
optim = torch.optim.SGD(params, **optim_kwargs)
return torch.optim.SGD(*args, **kwargs)
case "adamw":
optim = torch.optim.AdamW(params, **optim_kwargs)
return torch.optim.AdamW(*args, **kwargs)
case _:
raise NotImplementedError(f"Unsupported optimizer: `{opt_type}`")

sched_kwargs = (
self.per_task_sched_args.get(self.task_id)
or self.per_task_sched_args.get(-1)
or {}
)
sched_type = sched_kwargs.pop("type", "linear_warmup_cosine_annealing")
@staticmethod
def _get_scheduler(*args, **kwargs):
sched_type = kwargs.pop("type", "linear_warmup_cosine_annealing")
match sched_type:
case "linear_warmup_cosine_annealing":
sched = LinearWarmupCosineAnnealingLR(optim, **sched_kwargs)
return LinearWarmupCosineAnnealingLR(*args, **kwargs)
case "cosine_annealing":
sched = lr_scheduler.CosineAnnealingLR(optim, **sched_kwargs)
return lr_scheduler.CosineAnnealingLR(*args, **kwargs)
case "step_lr":
sched = lr_scheduler.StepLR(optim, **sched_kwargs)
return lr_scheduler.StepLR(*args, **kwargs)
case "multi_step_lr":
sched = lr_scheduler.MultiStepLR(optim, **sched_kwargs)
return lr_scheduler.MultiStepLR(*args, **kwargs)
case _:
raise NotImplementedError(f"Unsupported scheduler: `{sched_type}`")

def configure_optimizers(self):
params = [p for p in self.parameters() if p.requires_grad]

# a waterfall lookup for optimizer/scheduler kwargs:
# per-task specific > default (-1) > empty dict
optim_kwargs = (
self.per_task_optim_args.get(self.task_id)
or self.per_task_optim_args.get(-1)
or {}
)
sched_kwargs = (
self.per_task_sched_args.get(self.task_id)
or self.per_task_sched_args.get(-1)
or {}
)
optim = self._get_optimizer(params, **optim_kwargs)
sched = self._get_scheduler(optim, **sched_kwargs)

return {
"optimizer": optim,
"lr_scheduler": {"scheduler": sched, "interval": "epoch"},
@@ -196,7 +218,6 @@ def setup(self, stage) -> None:
if stage == "fit":
dm: HFDataModule = self.trainer.datamodule # ty: ignore[unresolved-attribute]
self.sync_with_datamodule(dm)
self.expand_head(self.num_seen_classes - self.num_old_classes)

def on_fit_end(self):
self.snapshot_old()
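
The waterfall lookup above resolves kwargs per task and falls back to the `-1` entry. A hedged example of what such dictionaries might look like; how they are actually populated (e.g. via the learner constructor) is not shown in this diff, and the concrete values are placeholders.

```python
# Assumed shape of the per-task dictionaries; "type" is consumed by the new
# _get_optimizer/_get_scheduler helpers, everything else is forwarded as kwargs.
learner.per_task_optim_args = {
    -1: {"type": "sgd", "lr": 0.1, "momentum": 0.9, "weight_decay": 5e-4},  # default for all tasks
    0: {"type": "adamw", "lr": 1e-3},                                       # override for task 0
}
learner.per_task_sched_args = {
    -1: {"type": "cosine_annealing", "T_max": 100},
}
```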
172 changes: 172 additions & 0 deletions src/lycil/learner/podnet.py
@@ -0,0 +1,172 @@
import math

import torch
import torch.nn.functional as F

from .icarl import ICaRL


def nca(
similarities: torch.Tensor,
targets: torch.Tensor,
class_weights: torch.Tensor | None = None,
scale: float = 1.0,
margin: float = 0.6,
exclude_pos_denominator: bool = True,
hinge_proxynca: bool = False,
) -> torch.Tensor:
margins = torch.zeros_like(similarities)
margins[torch.arange(margins.shape[0]), targets] = margin
similarities = scale * (similarities - margin)

if exclude_pos_denominator:
similarities = similarities - similarities.max(1)[0].view(-1, 1)

disable_pos = torch.zeros_like(similarities)
disable_pos[torch.arange(len(similarities)), targets] = similarities[
torch.arange(len(similarities)), targets
]

numerator = similarities[torch.arange(similarities.shape[0]), targets]
denominator = similarities - disable_pos

losses = numerator - torch.log(torch.exp(denominator).sum(-1))
if class_weights is not None:
losses = class_weights[targets] * losses

losses = -losses
if hinge_proxynca:
losses = torch.clamp(losses, min=0.0)

loss = torch.mean(losses)
return loss

return F.cross_entropy(
similarities, targets, weight=class_weights, reduction="mean"
)


def pod_spatial_loss(
old_fmaps: dict[str, torch.Tensor],
new_fmaps: dict[str, torch.Tensor],
normalize: bool = True,
distill_on_layers: list[str] = ["l1", "l2", "l3", "l4"],
) -> torch.Tensor:
loss: torch.Tensor = None # ty: ignore[invalid-assignment]
for layer in distill_on_layers:
a = old_fmaps[layer]
b = new_fmaps[layer]
assert a.shape == b.shape, "Shape error"

a = torch.pow(a, 2)
b = torch.pow(b, 2)

a_h = a.sum(dim=3).view(a.shape[0], -1) # [bs, c*w]
b_h = b.sum(dim=3).view(b.shape[0], -1) # [bs, c*w]
a_w = a.sum(dim=2).view(a.shape[0], -1) # [bs, c*h]
b_w = b.sum(dim=2).view(b.shape[0], -1) # [bs, c*h]

a = torch.cat([a_h, a_w], dim=-1)
b = torch.cat([b_h, b_w], dim=-1)

if normalize:
a = F.normalize(a, dim=1, p=2)
b = F.normalize(b, dim=1, p=2)

layer_loss = torch.mean(torch.frobenius_norm(a - b, dim=-1))
if loss is None:
loss = layer_loss
else:
loss += layer_loss

return loss / len(distill_on_layers)


class PODNet(ICaRL):
r"""`PODNet`_: Pooled Outputs Distillation for Small-Tasks Incremental Learning. (Douillard et al., ECCV 2020).
- Exemplar memory: herding + NME-based evaluation
- Loss :math:`L = L_\text{NCA} + \alpha_\text{task} \cdot (\lambda_\text{flat} L_\text{flat} + \lambda_\text{spatial} L_\text{spatial})`.

Args:
lambda_spatial (float, optional): Weight for spatial distillation loss. (default: 5.0)
lambda_flat (float, optional): Weight for flat distillation loss. (default: 1.0)
args: See :class:`BaseLearner` for other args.
kwargs: See :class:`BaseLearner` for other args.

.. _PODNet:
https://arxiv.org/abs/2004.13513
"""

def __init__(
self,
*args,
lambda_spatial: float = 5.0,
lambda_flat: float = 1.0,
**kwargs,
):
super().__init__(*args, **kwargs)

self.lambda_spatial = float(lambda_spatial)
self.lambda_flat = float(lambda_flat)

@property
def task_factor(self) -> float:
if self.task_id == 0:
return 0

return math.sqrt(
self.num_seen_classes / (self.num_seen_classes - self.num_old_classes)
)

def training_step(
self, batch: dict[str, torch.Tensor], batch_idx: int
) -> torch.Tensor:
x, y = self.unpack_batch(batch)
new_fmap = self.forward_layerwise(x)

# LSC classification loss (NCA) over all seen classes
loss_lsc = nca(new_fmap["logits"], y)

if self.task_id > 0:
# distill on old classes ($trainset \setminus cur$)
with torch.no_grad():
old_fmap = self.old_self.forward_layerwise(x)
loss_flat = F.cosine_embedding_loss(
new_fmap["features"],
old_fmap["features"].detach(),
torch.ones(x.shape[0]).to(self.device),
)
loss_spatial = pod_spatial_loss(old_fmap, new_fmap)

loss = loss_lsc + self.task_factor * (
self.lambda_spatial * loss_spatial + self.lambda_flat * loss_flat
)
else:
# first task, no distill
loss_spatial = None
loss_flat = None
loss = loss_lsc

self.log_dict(
{
"train/loss": loss,
"train/lsc": loss_lsc,
"train/flat": loss_flat or 0.0,
"train/spatial": loss_spatial or 0.0,
},
prog_bar=True,
on_epoch=True,
sync_dist=True,
)
return loss

def on_train_end(self):
# on_train_end is already implemented in ICaRL; overridden here to guard the memory update
dm = self.trainer.datamodule # ty: ignore[unresolved-attribute]

# update memory only after training on current-task data, not after a buffer-only replay pass
if dm.train_filter_fn is None:
self.update_memory(dm)

def on_fit_end(self):
self.snapshot_old()
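
A small worked example of the task weighting used above (`task_factor`), assuming 10 new classes per task; this is plain arithmetic on the formula in the diff, not additional project code.

```python
import math

# alpha_task = sqrt(num_seen / num_new): the distillation terms are weighted
# more heavily as the ratio of seen classes to new classes grows.
for num_seen, num_old in [(20, 10), (50, 40), (100, 90)]:
    alpha = math.sqrt(num_seen / (num_seen - num_old))
    print(num_seen, num_old, round(alpha, 2))  # 1.41, 2.24, 3.16
```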
2 changes: 1 addition & 1 deletion tests/training/test_icarl_cifar.py
@@ -2,9 +2,9 @@

import lightning as L
import pytest
import wandb
from lightning.pytorch.loggers import WandbLogger

import wandb
from lycil.constants import _EXP_NAME
from lycil.data.hfmodule import HFDataModule
from lycil.learner.icarl import ICaRL
2 changes: 1 addition & 1 deletion tests/training/test_lwf_cifar.py
@@ -2,9 +2,9 @@

import lightning as L
import pytest
import wandb
from lightning.pytorch.loggers import WandbLogger

import wandb
from lycil.constants import _EXP_NAME
from lycil.data.hfmodule import HFDataModule
from lycil.learner.lwf import LWF