Skip to content

Commit

Permalink
[docs] adding docs to convnext
Browse files Browse the repository at this point in the history
  • Loading branch information
Lorenzobattistela committed Aug 19, 2023
1 parent c48b95d commit dec8a1e
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions tensorflow_similarity/architectures/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,52 @@ def ConvNeXtSim(
pooling: str = "gem",
gem_p: float = 3.0,
) -> SimilarityModel:
""""Build an ConvNeXt Model backbone for similarity learning
[A ConvNet for the 2020s](https://arxiv.org/pdf/2201.03545.pdf)
Args:
input_shape: Size of the input image. Must match size of ConvNeXt version you use.
See below for version input size.
embedding_size: Size of the output embedding. Usually between 64
and 512. Defaults to 128.
variant: Which Variant of the ConvNeXt to use. Defaults to "BASE".
weights: Use pre-trained weights - the only available currently being
imagenet. Defaults to "imagenet".
trainable: Make the ConvNeXt backbone fully trainable or partially
trainable.
- "full" to make the entire backbone trainable,
- "partial" to only make the last 3 block trainable
- "frozen" to make it not trainable.
l2_norm: If True and include_top is also True, then
tfsim.layers.MetricEmbedding is used as the last layer, otherwise
keras.layers.Dense is used. This should be true when using cosine
distance. Defaults to True.
include_top: Whether to include the fully-connected layer at the top
of the network. Defaults to True.
pooling: Optional pooling mode for feature extraction when
include_top is False. Defaults to gem.
- None means that the output of the model will be the 4D tensor
output of the last convolutional layer.
- avg means that global average pooling will be applied to the
output of the last convolutional layer, and thus the output of the
model will be a 2D tensor.
- max means that global max pooling will be applied.
- gem means that global GeneralizedMeanPooling2D will be applied.
The gem_p param sets the contrast amount on the pooling.
gem_p: Sets the power in the GeneralizedMeanPooling2D layer. A value
of 1.0 is equivalent to GlobalMeanPooling2D, while larger values
will increase the contrast between activations within each feature
map, and a value of math.inf will be equivalent to MaxPool2d.
"""
inputs = layers.Input(shape=input_shape)
x = inputs

Expand Down Expand Up @@ -71,6 +117,21 @@ def ConvNeXtSim(


def build_convnext(variant: str, weights: str | None = None, trainable: str = "full") -> tf.keras.Model:
"""Build the requested ConvNeXt
Args:
variant: Which Variant of the ConvNeXt to use.
weights: Use pre-trained weights - the only available currently being
imagenet.
trainable: Make the ConvNeXt backbone fully trainable or partially
trainable.
- "full" to make the entire backbone trainable,
- "partial" to only make the last 3 block trainable
- "frozen" to make it not trainable.
Returns:
The output layer of the convnext model
"""
convnext_fn = CONVNEXT_ARCHITECTURE[variant.upper()]
convnext = convnext_fn(weights=weights, include_top=False)

Expand Down

0 comments on commit dec8a1e

Please sign in to comment.