diff --git a/vision_toolbox/backbones/convnext.py b/vision_toolbox/backbones/convnext.py index a32ee9a..9bd6c8b 100644 --- a/vision_toolbox/backbones/convnext.py +++ b/vision_toolbox/backbones/convnext.py @@ -55,7 +55,9 @@ def __init__( super().__init__() self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), norm(d_model)) + stochastic_depth_rates = torch.linspace(0, stochastic_depth, sum(depths)) self.stages = nn.Sequential() + for stage_idx, depth in enumerate(depths): stage = nn.Sequential() if stage_idx > 0: @@ -71,8 +73,9 @@ def __init__( downsample = nn.Identity() stage.append(downsample) - for _ in range(depth): - block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, stochastic_depth, norm, act) + for block_idx in range(depth): + rate = stochastic_depth_rates[sum(depths[:stage_idx]) + block_idx] + block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm, act) stage.append(block) self.stages.append(stage)