Major code refactor
DarshanDeshpande committed Feb 18, 2022
1 parent a314f90 commit 56928ac
Showing 20 changed files with 345 additions and 285 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -86,7 +86,7 @@ git clone https://github.com/DarshanDeshpande/jax-models.git
To see all model architectures available:

```py
from jax_models.models.model_registry import list_models
from jax_models import list_models
from pprint import pprint

pprint(list_models())
@@ -95,7 +95,7 @@ pprint(list_models())
To load your desired model:

```py
from jax_models.models.model_registry import load_model
from jax_models import load_model
load_model('swin-tiny-224', attach_head=True, num_classes=1000, dropout=0.0, pretrained=True)
```
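For context, a minimal inference sketch using the loaded model is shown below. It assumes `load_model` returns a Flax module together with its parameters, as the registered constructors in this commit do, and the exact variables-dict structure expected by `model.apply` is likewise an assumption.

```py
import jax.numpy as jnp
from jax_models import load_model

# Sketch only: assumes load_model returns (module, params) and that the
# parameter tree can be passed to Flax's apply as shown. The NHWC input
# layout is an assumption, typical for Flax vision models.
model, params = load_model(
    'swin-tiny-224', attach_head=True, num_classes=1000, dropout=0.0, pretrained=True
)
dummy = jnp.zeros((1, 224, 224, 3))
logits = model.apply({'params': params}, dummy, deterministic=True)  # inference mode
```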

1 change: 1 addition & 0 deletions jax_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import activations
from . import layers
from . import models
from .models.model_registry import list_models, load_model
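The core of the refactor is a decorator-based model registry exposed at the package root. The repository's actual `model_registry.py` is not shown in this diff; the following is only a minimal sketch of how such a registry typically works.

```py
# Minimal illustration of a decorator-based model registry (an assumption,
# not the repository's actual model_registry.py, which this diff omits).
_MODELS = {}

def register_model(fn):
    """Record a model constructor under its function name."""
    _MODELS[fn.__name__] = fn
    return fn

def list_models():
    """Return the names of all registered model constructors."""
    return sorted(_MODELS)

def load_model(name, **kwargs):
    """Look up a registered constructor by name and call it."""
    return _MODELS[name](**kwargs)
```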
74 changes: 54 additions & 20 deletions jax_models/models/cait.py
@@ -4,20 +4,22 @@

from ..layers import DropPath, TransformerMLP, PatchEmbed
from .helper import download_checkpoint_params, load_trained_params
from .model_registry import register_model

from typing import Optional

__all__ = [
"CaiT",
"XXS24_224",
"XXS24_384",
"XXS36_224",
"XXS36_384",
"XS24_384",
"S24_224",
"S24_384",
"S36_384",
"M36_384",
"M48_448",
"cait_xxs24_224",
"cait_xxs24_384",
"cait_xxs36_224",
"cait_xxs36_384",
"cait_xs24_384",
"cait_s24_224",
"cait_s24_384",
"cait_s36_384",
"cait_m36_384",
"cait_m48_448",
]

pretrained_cfgs = {
@@ -207,6 +209,28 @@ def __call__(self, x, deterministic=None):


class CaiT(nn.Module):
"""
Module for Class-Attention in Image Transformers
Attributes:
patch_size (int): Patch size. Default is 16.
embed_dim (int): Embedding dimension. Default is 768.
depth (int): Number of blocks. Default is 12.
num_heads (int): Number of attention heads. Default is 12.
mlp_ratio (int): Multiplier for hidden dimension in transformer MLP block. Default is 4.
use_att_bias (bool): Whether to use bias for linear qkv projection. Default is True.
drop (float): Dropout value. Default is 0.
att_dropout (float): Dropout value for attention. Default is 0.
drop_path (float): Dropout value for DropPath. Default is 0.
init_scale (float): Initialization scale used for gamma initialization. Default is 1e-4.
depth_token_only (int): Number of blocks with cls_token and class attention. Default is 2.
mlp_ratio_clstk (int): Multiplier for hidden dimension in transformer MLP block with class attention. Default is 4.
attach_head (bool): Whether to attach classification head. Default is True.
num_classes (int): Number of classification classes. Only works if attach_head is True. Default is 1000.
deterministic (bool): Optional argument; if True, the network becomes deterministic and dropout is not applied.
"""

patch_size: int = 16
embed_dim: int = 768
depth: int = 12
@@ -281,7 +305,8 @@ def __call__(self, inputs, deterministic=None):
return x


def XXS24_224(
@register_model
def cait_xxs24_224(
attach_head=True,
num_classes=1000,
dropout=0.0,
@@ -310,7 +335,8 @@ def XXS24_224(
return model, params


def XXS24_384(
@register_model
def cait_xxs24_384(
attach_head=True,
num_classes=1000,
dropout=0.0,
@@ -339,7 +365,8 @@ def XXS24_384(
return model, params


def XXS36_224(
@register_model
def cait_xxs36_224(
attach_head=True,
num_classes=1000,
dropout=0.0,
@@ -368,7 +395,8 @@ def XXS36_224(
return model, params


def XXS36_384(
@register_model
def cait_xxs36_384(
attach_head=True,
num_classes=1000,
dropout=0.0,
@@ -397,7 +425,8 @@ def XXS36_384(
return model, params


def XS24_384(
@register_model
def cait_xs24_384(
attach_head=True,
num_classes=1000,
dropout=0.0,
@@ -426,7 +455,8 @@ def XS24_384(
return model, params


def S24_224(
@register_model
def cait_s24_224(
attach_head=True,
num_classes=1000,
dropout=0.0,
@@ -455,7 +485,8 @@ def S24_224(
return model, params


def S24_384(
@register_model
def cait_s24_384(
attach_head=True,
num_classes=1000,
dropout=0.0,
@@ -484,7 +515,8 @@ def S24_384(
return model, params


def S36_384(
@register_model
def cait_s36_384(
attach_head=True,
num_classes=1000,
dropout=0.0,
@@ -513,7 +545,8 @@ def S36_384(
return model, params


def M36_384(
@register_model
def cait_m36_384(
attach_head=True,
num_classes=1000,
dropout=0.0,
@@ -542,7 +575,8 @@ def M36_384(
return model, params


def M48_448(
@register_model
def cait_m48_448(
attach_head=True,
num_classes=1000,
dropout=0.0,
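The same rename-and-register pattern is applied to every CaiT variant above: each CamelCase constructor (e.g. `XXS24_224`) becomes a lowercase, prefixed one (`cait_xxs24_224`) decorated with `@register_model`. A hedged sketch of calling one of the renamed constructors directly; the return value is inferred from the `return model, params` lines visible in the diff, and the behavior of `params` without pretrained weights is an assumption.

```py
from jax_models.models.cait import cait_xxs24_224

# Assumption: the constructor returns (model, params); params holds
# pretrained weights only when they are requested, otherwise it may be None.
model, params = cait_xxs24_224(attach_head=True, num_classes=1000, dropout=0.0)
```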
21 changes: 13 additions & 8 deletions jax_models/models/conv_mixer.py
@@ -2,6 +2,7 @@
import flax.linen as nn

from ..layers import DepthwiseConv2D
from .model_registry import register_model

from typing import Optional
import logging
@@ -10,10 +11,10 @@

__all__ = [
"ConvMixer",
"ConvMixer_512_12",
"ConvMixer_768_32",
"ConvMixer_1024_20",
"ConvMixer_1536_20",
"convmixer_512_12",
"convmixer_768_32",
"convmixer_1024_20",
"convmixer_1536_20",
]


@@ -79,7 +80,8 @@ def __call__(self, inputs, deterministic=None):
return x


def ConvMixer_1536_20(
@register_model
def convmixer_1536_20(
attach_head=False,
num_classes=1000,
dropout=None,
@@ -94,7 +96,8 @@ def ConvMixer_1536_20(
return ConvMixer(1536, 7, 20, 9, attach_head, num_classes, **kwargs)


def ConvMixer_768_32(
@register_model
def convmixer_768_32(
attach_head=False,
num_classes=1000,
dropout=None,
@@ -109,7 +112,8 @@ def ConvMixer_768_32(
return ConvMixer(768, 7, 32, 7, attach_head, num_classes, **kwargs)


def ConvMixer_512_12(
@register_model
def convmixer_512_12(
attach_head=False,
num_classes=1000,
dropout=None,
@@ -124,7 +128,8 @@ def ConvMixer_512_12(
return ConvMixer(512, 7, 12, 8, attach_head, num_classes, **kwargs)


def ConvMixer_1024_20(
@register_model
def convmixer_1024_20(
attach_head=False,
num_classes=1000,
dropout=None,
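The ConvMixer constructors follow the same registration pattern, but unlike the CaiT ones, the lines visible here return the `ConvMixer` module directly rather than a `(model, params)` tuple. A sketch of direct use; the mapping of the positional config (e.g. 1536 features, patch size 7, depth 20, kernel size 9 for `convmixer_1536_20`) is an assumption based on the ConvMixer naming convention.

```py
from jax_models.models.conv_mixer import convmixer_1536_20

# The visible diff shows this constructor returning the ConvMixer module
# directly; any pretrained-weight handling sits in the collapsed part of
# the diff, so it is not assumed here.
model = convmixer_1536_20(attach_head=True, num_classes=1000)
```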