A deep-dive portfolio project building supervised SimCSE from scratch. This repo covers the full pipeline: aligned NLI data processing, advanced contrastive loss, and a custom BERT-based model in PyTorch.
This project is a from-scratch implementation of the supervised portion of the SimCSE: Simple Contrastive Learning of Sentence Embeddings paper.
The repository documents the journey from a naive first attempt (Version 1), which suffered from a critical data-processing bug, through a corrected baseline (Version 2), to a high-performance model (Version 3) that achieves a 0.8524 Spearman correlation on the STSb (Semantic Textual Similarity Benchmark) test set.
This project is a case study in the importance of iterative development and applying SOTA techniques.
The initial embed-v1.ipynb notebook represented a first pass at the problem. It successfully set up the model, loss, and training loop.
The Bug: The data pipeline was flawed. It created three separate lists for premise, positives (entailment), and negatives (contradiction) and then zip()-ed them together. This resulted in mismatched (Anchor, Positive, Negative) triplets. The model was being trained on nonsensical data, where the positive and negative sentences had no logical connection to the anchor sentence.
The Result: The model failed to learn meaningful semantic representations, resulting in a very low Spearman correlation of 0.184 on the STSb test set.
```python
# V1 Bug Example: Lists created independently
premise   = [ex['premise'] for ex in dataset if ex['label'] in [0, 2]]
positives = [ex['hypothesis'] for ex in dataset if ex['label'] == 0]
negatives = [ex['hypothesis'] for ex in dataset if ex['label'] == 2]
triplets = list(zip(premise, positives, negatives))  # ❌ MISMATCHED!
```

The V2 implementation (`embed_v2.ipynb`) diagnosed and fixed this critical bug, establishing a solid baseline.
The Fix: The data pipeline was completely rebuilt using pandas.groupby to group the SNLI and MNLI datasets by premise. This created aligned triplets, ensuring that for every (Anchor: premise), the (Positive: entailment) and (Negative: contradiction) were both logically associated with it.
The Result: With the data pipeline corrected, the model was able to learn effectively, achieving a respectable baseline Spearman correlation of 0.628 on 200k samples.
```python
# V2 Fix: Grouping by premise for aligned triplets
triplets = []
grouped = df.groupby('premise')
for premise, group in grouped:
    positives = group[group['label'] == 0]['hypothesis'].tolist()
    negatives = group[group['label'] == 2]['hypothesis'].tolist()
    if positives and negatives:
        triplets.append((premise, positives[0], negatives[0]))  # ✅ ALIGNED!
```

The V3 implementation (`embed-v3.ipynb`) was built to push the baseline score to a state-of-the-art level by incorporating three key engineering upgrades.
The training data was scaled up from 200k samples to the full SNLI (~550k) and MNLI (~392k) datasets (~940k examples in total), yielding a training set of 140,000+ aligned triplets.
```python
CONFIG = {
    "snli_samples": 550_152,  # Full SNLI
    "mnli_samples": 392_702,  # Full MNLI
    ...
}
```

The weak `[CLS]` token pooling from V2 was replaced with a Mean Pooling layer. This creates a much richer sentence embedding by averaging the last hidden states of all tokens (excluding padding), capturing the full semantic context of the sentence.
```python
def mean_pooling(self, hidden_state, attention_mask):
    # Expand the attention mask and apply it to the hidden states
    mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
    sum_embeddings = torch.sum(hidden_state * mask_expanded, 1)
    sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask  # Average over non-padding tokens
```

The simple 1-vs-1 triplet loss was upgraded to an InfoNCE loss with in-batch negatives. For an anchor sentence `A[i]`, its positive is `P[i]`, but its negatives become the 127 other candidates in the batch: the other 63 positives `P[j]` and all 64 negatives `N[j]`. This 1-vs-127 comparison is a much harder task and forces the model to learn a highly discriminative embedding space.
```python
# V3: In-batch negatives (batch_size = 64)
# For anchor A[i], compare against 128 candidates:
#   - 1 true positive P[i]
#   - 127 negatives: all other P[j] and N[j] in the batch
all_embeddings = torch.cat([pos_embed, neg_embed], dim=0)  # Shape: (128, 768)
similarity_matrix = anchor_embed @ all_embeddings.T        # Shape: (64, 128)
# Target: P[i] sits at index i in the concatenated tensor
labels = torch.arange(batch_size).to(device)
loss = F.cross_entropy(similarity_matrix / temperature, labels)
```

The Result: The combination of these three techniques pushed the model's performance to a 0.8524 Spearman correlation, demonstrating a robust, high-performance, from-scratch implementation.
- Data Pipeline: Loads the full `snli` (~550k) and `glue/mnli` (~392k) datasets. It filters for entailment (label 0) and contradiction (label 2), then groups by `premise` to create 140,000+ aligned (A, P, N) triplets.
- Model Architecture: A `bert-base-uncased` model with a custom mean pooling layer. This layer averages the last hidden state of all non-padding tokens to produce the final sentence embedding, which is then L2-normalized (see the sketch after this list).
- Loss Function: An InfoNCE (contrastive) loss with in-batch negatives. For a batch size of 64, each anchor `A[i]` is compared against a set of 128 candidates (64 positives and 64 negatives). The model uses a cross-entropy loss to find the single correct positive `P[i]` from this set.
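As a rough reference for how these pieces fit together, here is a minimal sketch of the encoder plus mean-pooling forward pass. The class name `SimCSEModel` and the exact method signatures are illustrative, not taken verbatim from the notebooks.

```python
# Hedged sketch: BERT encoder + mean pooling + L2 normalization.
# `SimCSEModel` and its method names are illustrative; the notebooks may differ.
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

class SimCSEModel(nn.Module):
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)

    def mean_pooling(self, hidden_state, attention_mask):
        # Average the last hidden states of non-padding tokens only
        mask = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
        return (hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

    def forward(self, input_ids, attention_mask):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.mean_pooling(out.last_hidden_state, attention_mask)
        return F.normalize(pooled, p=2, dim=1)  # L2-normalized sentence embedding

# Usage
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = SimCSEModel()
batch = tokenizer(["A man is playing guitar."], padding=True, return_tensors="pt")
embedding = model(batch["input_ids"], batch["attention_mask"])  # shape: (1, 768)
```

Because the output is L2-normalized, a plain dot product between two embeddings is already their cosine similarity, which is what both the InfoNCE similarity matrix and the STSb evaluation rely on.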
The iterative journey from V1 to V3 is summarized below by Spearman correlation on the full STSb test set.
| Version | Model Architecture | Data (Aligned) | Spearman Score (STSb) |
|---|---|---|---|
| V1 (Buggy) | [CLS] Token | 10k SNLI (Mismatched) | 0.184 |
| V2 (Baseline) | [CLS] Token | 200k SNLI+MNLI (Aligned) | 0.628 |
| V3 (SOTA) | Mean Pooling + In-Batch Loss | Full SNLI+MNLI (Aligned) | 0.8524 |
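These scores come from comparing the cosine similarity of each sentence pair's embeddings against the human similarity ratings with Spearman's rank correlation. A minimal sketch of that evaluation, assuming a hypothetical `encode_sentences` helper that returns L2-normalized embeddings (the exact STSb dataset identifier and split used in the notebooks may differ):

```python
# Hedged sketch of the STSb evaluation; `encode_sentences` is a hypothetical helper
# that tokenizes a list of sentences and returns L2-normalized embeddings as a NumPy array.
import numpy as np
from datasets import load_dataset
from scipy.stats import spearmanr

# Assumption: the GLUE copy of STSb; the notebooks may load a different variant/split.
stsb = load_dataset("glue", "stsb", split="validation")

emb1 = encode_sentences(stsb["sentence1"])  # shape: (N, 768)
emb2 = encode_sentences(stsb["sentence2"])  # shape: (N, 768)

# With L2-normalized vectors, the row-wise dot product is the cosine similarity
cosine_scores = np.sum(emb1 * emb2, axis=1)

score, _ = spearmanr(cosine_scores, stsb["label"])  # gold scores range 0-5
print(f"Spearman correlation: {score:.4f}")
```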
```bash
git clone https://github.com/ron-42/simcse-from-scratch.git
cd simcse-from-scratch
pip install -r requirements.txt
```

- Open `embed-v3.ipynb` in Kaggle or Google Colab
- In Kaggle Settings:
  - Set Accelerator to GPU (T4, P100, etc.)
  - Turn Internet to ON
- Run the cells in the notebook
⚠️ Important: Cell 1 fixes Kaggle's environment. After running Cell 1, you must restart the session (Run > Restart Session) before running Cell 2 and the rest of the notebook.
- Open `embed_v2.ipynb` in Jupyter, Colab, or Kaggle
- Follow the same GPU setup steps
- Run all cells sequentially
- Open `embed-v1.ipynb` to see the original buggy implementation
- Included for educational purposes to demonstrate the importance of correct data alignment
- `torch`
- `transformers`
- `datasets`
- `pandas`
- `tqdm`
- `scikit-learn` (for `cosine_similarity`)
- `scipy` (for `spearmanr`)
- `ipywidgets` (for notebook progress bars)
- Implement the unsupervised version of SimCSE, which uses standard dropout as data augmentation to create positive pairs from the same sentence (a rough sketch of the idea follows this list).
- Train on a larger, more diverse dataset (e.g., all of Wikipedia) to create a more general-purpose embedding model.
- Experiment with hard negative mining to further improve discriminative power.
- Deploy the model as a sentence embedding API for real-world applications.
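As a rough illustration of the unsupervised-SimCSE idea (not implemented in this repo yet): each sentence is encoded twice with dropout active, and the two dropout-perturbed views serve as the positive pair for an in-batch InfoNCE loss. A minimal sketch, assuming a `model` and `tokenizer` like the ones sketched earlier:

```python
# Hedged sketch of unsupervised SimCSE: dropout noise is the only "augmentation".
# Assumes `model` returns L2-normalized embeddings and `tokenizer` is a BERT tokenizer
# (illustrative names from the sketch above, not the repo's actual code).
import torch
import torch.nn.functional as F

sentences = ["A man is playing guitar.", "Two dogs run across a field."]
batch = tokenizer(sentences, padding=True, return_tensors="pt")

model.train()  # keep dropout active so the two passes produce different embeddings

# Encode the SAME sentences twice -> two dropout-perturbed views z1 and z2
z1 = model(batch["input_ids"], batch["attention_mask"])  # (B, 768), L2-normalized
z2 = model(batch["input_ids"], batch["attention_mask"])  # (B, 768), L2-normalized

# In-batch InfoNCE: z2[i] is the positive for z1[i]; every other z2[j] is a negative
temperature = 0.05
similarity = z1 @ z2.T                                   # (B, B) cosine similarities
labels = torch.arange(z1.size(0), device=z1.device)
loss = F.cross_entropy(similarity / temperature, labels)
```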
This project is an implementation of the ideas presented in the original paper:
Gao, T., Yao, X., & Chen, D. (2021). SimCSE: Simple Contrastive Learning of Sentence Embeddings. arXiv preprint arXiv:2104.08821.
This project is open source and available under the MIT License.