Add question 238: SimCLR Contrastive Loss (NT-Xent) #578
Open: BARALLL wants to merge 1 commit into Open-Deep-ML:main from BARALLL:new-question-238-contrastive
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| { | ||
| "id": "238", | ||
| "title": "SimCLR Contrastive Loss (NT-Xent)", | ||
| "difficulty": "medium", | ||
| "category": "Deep Learning", | ||
| "video": "", | ||
| "likes": "0", | ||
| "dislikes": "0", | ||
| "contributor": [ | ||
| { | ||
| "profile_link": "https://github.com/BARALLL", | ||
| "name": "baralm" | ||
| } | ||
| ], | ||
| "description": "\\## NT-Xent Loss for Self-Supervised Contrastive Learning\n\n\n\nIn self-supervised contrastive learning frameworks like \\*\\*SimCLR\\*\\*, we learn meaningful representations without labels by:\n\n1\\. Creating two augmented \"views\" of each image\n\n2\\. Training the model to recognize that views of the \\*\\*same\\*\\* image should have similar embeddings\n\n3\\. While views of \\*\\*different\\*\\* images should have dissimilar embeddings\n\n\n\n\\### The Problem\n\nYou are given a batch of $N$ images. For each image, we generate 2 augmented views, resulting in a batch size of $2N$. The embeddings are organized in an \\*\\*interleaved\\*\\* fashion:\n\n\\- Rows $2k$ and $2k+1$ are two views of the same image $k$ (a positive pair).\n\n\\- Any other pair of rows constitutes a negative pair.\n\n\n\nFor a specific sample $i$, let $j$ be its positive pair. The \\*\\*NT-Xent (Normalized Temperature-scaled Cross-Entropy)\\*\\* loss for sample $i$ is defined as:\n\n\n\n$$\n\n\\\\ell\\_i = -\\\\log \\\\frac{\\\\exp(\\\\text{sim}(z\\_i, z\\_j) / \\\\tau)}{\\\\sum\\_{k=1}^{2N} \\\\mathbb{1}\\_{\\[k \\\\neq i]} \\\\exp(\\\\text{sim}(z\\_i, z\\_k) / \\\\tau)}\n\n$$\n\n\n\nWhere:\n\n\\- $z$ is the batch of L2-normalized embeddings.\n\n\\- $\\\\text{sim}(u, v) = u^\\\\top v$ (Cosine similarity, since $u, v$ are normalized).\n\n\\- $\\\\mathbb{1}\\_{\\[k \\\\neq i]}$ is an indicator function (returns 1 if $k \\\\neq i$, else 0). 
Effectively, we sum over all samples except the sample itself.\n\n\\- $\\\\tau$ is the temperature parameter.\n\n\n\nThe total loss is the arithmetic mean over all $2N$ samples: $L = \\\\frac{1}{2N} \\\\sum\\_{i=0}^{2N-1} \\\\ell\\_i$.\n\n\n\n\\### Your Task\n\nImplement the function `nt\\_xent\\_loss(z, temperature)` that computes the NT-Xent loss using vectorized NumPy operations.\n\n\n\n\\*\\*Input Format\\*\\*\n\n\\- `z`: A numpy array of shape `(2N, embedding\\_dim)` containing \\*\\*L2-normalized\\*\\* embeddings.\n\n - \\*\\*Structure\\*\\*: The rows are interleaved such that `z\\[2k]` and `z\\[2k+1]` form a positive pair (two views of image $k$).\n\n - Visually: `\\[View1\\_Img1, View2\\_Img1, View1\\_Img2, View2\\_Img2, ...]`.\n\n - All other interactions `z\\[i]` and `z\\[j]` (where `j` is not the pair of `i`) are considered negatives.\n\n\\- `temperature`: A float scaling parameter ($\\\\tau > 0$).\n\n\n\n\\### Output Format\n\n\\- Returns `float`: The average NT-Xent loss over all $2N$ samples.\n\n\n\n\\### Note on Stability\n\n\\- You should implement the \\*\\*Log-Sum-Exp trick\\*\\* (subtracting the maximum value before exponentiation) to ensure numerical stability.\n\n\n\n\\### Constraints\n\n\\- $N \\\\geq 1$ (at least 1 image, so batch size $\\\\geq 2$)\n\n\\- `embedding\\_dim` $\\\\geq 1$\n\n\\- `temperature` $> 0$\n\n\\- Input embeddings are guaranteed to be L2-normalized\n\n\\- \\*\\*Performance:\\*\\* Avoid explicit `for` loops. Use matrix operations and broadcasting.", | ||
| "learn_section": "\n# Learn Section\n\n# Understanding NT-Xent Loss (Normalized Temperature-scaled Cross Entropy)\n\n### 1. The Intuition: \"Find Your Partner\"\nAt its core, Self-Supervised Learning (SSL) creates a \"pretext task\" from unlabeled data. \nImagine a crowded room of people (embeddings). Everyone has a generic twin. The goal of NT-Xent is to make you stand as close as possible to your twin (alignment), while pushing everyone else away (uniformity).\n\n* **Positive pairs**: Different views (augmentations) of the SAME image → should be **SIMILAR**.\n* **Negative pairs**: Views of DIFFERENT images → should be **DISSIMILAR**.\n\n### 2. Generating Views\nWe take a batch of $N$ images and generate 2 views for each, resulting in a batch size of $2N$.\n\n```text\nImage A ───────── Augment ──→ View A₁ ──┐\n │ ├── Should be near (Positive) ✓\n └───── Augment ───→ View A₂ ──┘\n\nImage B ───────── Augment ──→ View B₁ ──┐\n ├── Should be far (Negative) ✗\nImage A ───────── Augment ──→ View A₁ ──┘\n```\n\n### 3. The Math: A Classification Problem\nThe NT-Xent loss is essentially a **Softmax Cross-Entropy** loss. We treat the positive pair as the \"correct class\" and all other images in the batch as \"negative classes.\"\n\n$$\\ell_i = -\\log \\frac{\\exp(\\text{sim}(z_i, z_j) / \\tau)}{\\sum_{k \\neq i} \\exp(\\text{sim}(z_i, z_k) / \\tau)}$$\n\n**Breakdown:**\n* **Numerator**: The score of the positive pair ($z_i, z_j$). We want this high.\n* **Denominator**: The sum of scores of $z_i$ against *all* other samples (negatives + positive).\n* **Goal**: By maximizing the numerator relative to the denominator, we force the model to learn unique features that distinguish sample $i$ from the crowd.\n\n### 4. Visualizing the Similarity Matrix\nThe implementation relies on a $(2N \\times 2N)$ similarity matrix. 
If we organize our batch as `[Cat1, Cat2, Dog1, Dog2]`, the ideal matrix looks like this:\n\n$$\n\\begin{bmatrix}\n\\text{Mask} & \\mathbf{\\text{High}} & \\text{Low} & \\text{Low} \\\\\n\\mathbf{\\text{High}} & \\text{Mask} & \\text{Low} & \\text{Low} \\\\\n\\text{Low} & \\text{Low} & \\text{Mask} & \\mathbf{\\text{High}} \\\\\n\\text{Low} & \\text{Low} & \\mathbf{\\text{High}} & \\text{Mask}\n\\end{bmatrix}\n$$\n\n1. **The Diagonal (Masked)**: We explicitly ignore comparing an image to itself ($k \\neq i$).\n2. **The Off-diagonals**: These are the **Positive Pairs**. We want to maximize these values.\n3. **Everything else**: These are **Negatives**. We want to minimize these values.\n\n### 5. The Critical Role of Temperature ($\\tau$)\nThe temperature $\\tau$ scales the dot products before the softmax. It controls how much the model focuses on difficult examples.\n\n* **High $\\tau$ (e.g., 1.0)**: The distribution is smoother. The model treats all negatives roughly equally.\n* **Low $\\tau$ (e.g., 0.1)**: The distribution becomes sharp/peaky. The model ignores easy negatives and focuses heavily on **\"Hard Negatives\"** (images that look similar to the anchor but aren't).\n\n$$\\text{As } \\tau \\to 0: \\text{Loss approaches argmax (winner-take-all)}$$\n\n### 6. Why It Works: Alignment & Uniformity\nResearch shows this loss optimizes two specific geometric properties on the embedding hypersphere:\n1. **Alignment**: Two views of the same image map to nearby points.\n2. **Uniformity**: Feature vectors spread roughly uniformly across the sphere. This prevents **feature collapse**, where the model maps all images to the same constant vector to cheat the loss.\n\n### 7. Implementation Steps\n1. **Forward Pass**: Get normalized embeddings $z$ (shape $2N \\times D$).\n2. **Similarity**: Compute matrix $S = z \\cdot z^T$ (shape $2N \\times 2N$).\n3. **Scale**: Divide $S$ by $\\tau$.\n4. **Mask**: Set diagonal values to $-\\infty$ (so exp() becomes 0).\n5. 
**Labels**: Create target labels. If batch is organized as `[View1_A, View1_B, ..., View2_A, View2_B...]`, then $i$ matches with $i + N$.\n6. **Loss**: Apply Standard Cross Entropy.\n\n### 6. Numerical Stability (Log-Sum-Exp Trick)\nComputers struggle with large exponents. If $\\text{sim}=1.0$ and $\\tau=0.01$, then $e^{100}$ is huge. To prevent overflow, we use the identity:\n\n$$ \\log \\left( \\sum e^{x_i} \\right) = a + \\log \\left( \\sum e^{x_i - a} \\right) $$\n\nwhere $a = \\max(x)$.\nBy subtracting the maximum value from the logits before exponentiating, the largest term becomes $e^0 = 1$, preventing overflow while keeping the probabilities mathematically identical.\n\n### 8. Connection to InfoNCE\nNT-Xent is a specific form of InfoNCE (Noise Contrastive Estimation) loss, which has theoretical connections to maximizing mutual information between views.\n\n### References\n* **SimCLR Paper**: [Chen et al., 2020](https://arxiv.org/abs/2002.05709)\n ", | ||
| "starter_code": "def nt_xent_loss(z: np.ndarray, temperature: float) -> float:\n \"\"\"\n Compute the NT-Xent loss for contrastive learning.\n \n Args:\n z: L2-normalized embeddings, shape (2N, embedding_dim)\n Positive pairs: z[2k] and z[2k+1] are views of image k\n temperature: Temperature scaling parameter (τ > 0)\n \n Returns:\n The scalar NT-Xent loss value\n \"\"\"\n pass", | ||
| "solution": "def nt_xent_loss(z: np.ndarray, temperature: float) -> float:\n N = z.shape[0]\n sim = (z @ z.T) / temperature\n \n sim_exp = np.exp(sim - np.max(sim, axis=1, keepdims=True))\n mask_diag = ~np.eye(N, dtype=bool)\n denominator = np.sum(sim_exp * mask_diag, axis=1)\n \n indices = np.arange(N)\n pos_indices = indices + 1 - 2 * (indices % 2)\n numerator = sim_exp[indices, pos_indices]\n \n losses = -np.log(numerator / denominator)\n \n return float(np.mean(losses))", | ||
| "example": { | ||
| "input": "# N=2 (Total batch 4).\n# 0 and 1 are views of Cat. 2 and 3 are views of Dog.\n# 0 matches 1 (Positive). 0 mismatches 2 and 3 (Negatives).\n\nz = np.array([\n [1.0, 0.0], # 0: Cat View A\n [1.0, 0.0], # 1: Cat View B (Perfect match with 0)\n [0.0, 1.0], # 2: Dog View A (Orthogonal to 0)\n [0.0, 1.0] # 3: Dog View B (Orthogonal to 0)\n])\ntemperature = 0.5", | ||
| "output": "2.2395447662218846", | ||
| "reasoning": "Let's calculate loss for index 0 (Cat A):\n\n1. **Positive Pair**: Index 1. Sim = 1.0.\n\n2. **Negatives**: Index 2, Index 3. Sim = 0.0.\n\n3. **Terms**:\n\n - Numerator (Positive): $\\exp(1.0 / 0.5) = \\exp(2) \\approx 7.389$\n\n - Denominator (All $k \\neq 0$):\n\n - Term 1 (Pos): $\\exp(1.0/0.5) \\approx 7.389$\n\n - Term 2 (Neg): $\\exp(0.0/0.5) = 1$\n\n - Term 3 (Neg): $\\exp(0.0/0.5) = 1$\n\n - Denominator Sum: $7.389 + 1 + 1 = 9.389$\n\n4. **Loss for index 0**: $-\\log(7.389 / 9.389) \\approx 0.239$\n\n\n\nSince the setup is symmetric, all 4 indices have the same loss." | ||
| }, | ||
| "test_cases": [ | ||
| { | ||
| "test": "print(nt_xent_loss([[1, 0], [1, 0], [0, 1], [0, 1]], 1.0))", | ||
| "expected_output": "0.5514447139320511" | ||
| }, | ||
| { | ||
| "test": "print(nt_xent_loss([[1, 0], [1, 0]], 1.0))", | ||
| "expected_output": "0.0" | ||
| }, | ||
| { | ||
| "test": "print(nt_xent_loss([[1, 0], [1, 0], [1, 0], [1, 0]], 1.0))", | ||
| "expected_output": "1.0986122886681098" | ||
| }, | ||
| { | ||
| "test": "print(nt_xent_loss([[1, 0], [1, 0], [-1, 0], [-1, 0]], 0.5))", | ||
| "expected_output": "0.035976299748193295" | ||
| }, | ||
| { | ||
| "test": "print(nt_xent_loss([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], 1.0))", | ||
| "expected_output": "1.0986122886681098" | ||
| }, | ||
| { | ||
| "test": "print(nt_xent_loss([[1, 0], [1, 0], [0, 1], [0, 1]], 100.0))", | ||
| "expected_output": "1.0919567454272663" | ||
| }, | ||
| { | ||
| "test": "print(nt_xent_loss([[1, 0], [1, 0], [0, 1], [0, 1]], 0.5))", | ||
| "expected_output": "0.23954476622188456" | ||
| }, | ||
| { | ||
| "test": "print(nt_xent_loss([[1, 0], [1, 0], [0, 1], [0, 1]], 2.0))", | ||
| "expected_output": "0.7943767694176431" | ||
| }, | ||
| { | ||
| "test": "print(nt_xent_loss([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]], 1.0))", | ||
| "expected_output": "0.904832441554448" | ||
| }, | ||
| { | ||
| "test": "print(nt_xent_loss([[1, 0], [0.8, 0.6], [0, 1], [-0.6, 0.8]], 1.0))", | ||
| "expected_output": "0.6735767888870939" | ||
| }, | ||
| { | ||
| "test": "print(nt_xent_loss([[1, 0], [-1, 0], [0, 1], [0, -1]], 1.0))", | ||
| "expected_output": "1.861994804058251" | ||
| }, | ||
| { | ||
| "test": "print(nt_xent_loss([[1, 0], [0, 1], [0.99, 0.141], [0.141, 0.99]], 1.0))", | ||
| "expected_output": "1.4700659232878173" | ||
| }, | ||
| { | ||
| "test": "print(nt_xent_loss([[1], [1]], 1.0))", | ||
| "expected_output": "0.0" | ||
| }, | ||
| { | ||
| "test": "print(nt_xent_loss([[1.0, 0.0], [0.707, 0.707], [-1.0, 0.0], [-0.707, 0.707]], 1.0))", | ||
| "expected_output": "0.4528130954640332" | ||
| }, | ||
| { | ||
| "test": "import numpy as np; np.random.seed(42); N = 1000; dim = 64; temperature = 0.1; embeddings = np.random.randn(2 * N, dim); z = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True); print(nt_xent_loss(z, temperature))", | ||
| "expected_output": "8.37825890387809" | ||
| }, | ||
| { | ||
| "test": "import numpy as np; np.random.seed(42); N = 8; dim = 8192; temperature = 0.5; embeddings = np.random.randn(2 * N, dim); z = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True); print(nt_xent_loss(z, temperature))", | ||
| "expected_output": "2.7080214605287156" | ||
| } | ||
| ] | ||
| } |
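
The test cases above pass plain Python lists, while the submitted solution reads `z.shape`, and its `exp(sim - max)` terms can underflow to zero at very small temperatures. The following is a minimal sketch of a more defensive variant, not the PR's code: the name `nt_xent_loss_stable` is hypothetical, and it assumes the grader does not convert inputs to arrays itself. It coerces with `np.asarray`, masks the diagonal with `-inf`, and keeps the computation in log space.

```python
import numpy as np


def nt_xent_loss_stable(z, temperature: float) -> float:
    """Hypothetical defensive variant of the PR's nt_xent_loss (interleaved pairs)."""
    z = np.asarray(z, dtype=float)              # accept lists as well as arrays
    n = z.shape[0]
    logits = (z @ z.T) / temperature            # (2N, 2N) scaled similarities
    np.fill_diagonal(logits, -np.inf)           # drop the k == i terms

    # Log-sum-exp per row: log(sum_k exp(x_k)) = a + log(sum_k exp(x_k - a))
    row_max = logits.max(axis=1, keepdims=True)
    log_denominator = row_max[:, 0] + np.log(np.exp(logits - row_max).sum(axis=1))

    idx = np.arange(n)
    pos = idx ^ 1                               # 0<->1, 2<->3, ...
    return float(np.mean(log_denominator - logits[idx, pos]))


# Same input as the PR's example (temperature = 0.5); the result is about 0.2395.
print(nt_xent_loss_stable([[1, 0], [1, 0], [0, 1], [0, 1]], 0.5))
```

Because the positive logit is subtracted from the log-denominator directly, the ratio is never formed, so a positive term that would underflow in `exp` still yields a finite loss.
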
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| \## NT-Xent Loss for Self-Supervised Contrastive Learning | ||
| | ||
| | ||
| | ||
| In self-supervised contrastive learning frameworks like \*\*SimCLR\*\*, we learn meaningful representations without labels by: | ||
| | ||
| 1\. Creating two augmented "views" of each image | ||
| | ||
| 2\. Training the model to recognize that views of the \*\*same\*\* image should have similar embeddings | ||
| | ||
| 3\. While views of \*\*different\*\* images should have dissimilar embeddings | ||
| | ||
| | ||
| | ||
| \### The Problem | ||
| | ||
| You are given a batch of $N$ images. For each image, we generate 2 augmented views, resulting in a batch size of $2N$. The embeddings are organized in an \*\*interleaved\*\* fashion: | ||
| | ||
| \- Rows $2k$ and $2k+1$ are two views of the same image $k$ (a positive pair). | ||
| | ||
| \- Any other pair of rows constitutes a negative pair. | ||
| | ||
| | ||
| | ||
| For a specific sample $i$, let $j$ be its positive pair. The \*\*NT-Xent (Normalized Temperature-scaled Cross-Entropy)\*\* loss for sample $i$ is defined as: | ||
| | ||
| | ||
| | ||
| $$ | ||
| | ||
| \\ell\_i = -\\log \\frac{\\exp(\\text{sim}(z\_i, z\_j) / \\tau)}{\\sum\_{k=1}^{2N} \\mathbb{1}\_{\[k \\neq i]} \\exp(\\text{sim}(z\_i, z\_k) / \\tau)} | ||
| | ||
| $$ | ||
| | ||
| | ||
| | ||
| Where: | ||
| | ||
| \- $z$ is the batch of L2-normalized embeddings. | ||
| | ||
| \- $\\text{sim}(u, v) = u^\\top v$ (Cosine similarity, since $u, v$ are normalized). | ||
| | ||
| \- $\\mathbb{1}\_{\[k \\neq i]}$ is an indicator function (returns 1 if $k \\neq i$, else 0). Effectively, we sum over all samples except the sample itself. | ||
| | ||
| \- $\\tau$ is the temperature parameter. | ||
| | ||
| | ||
| | ||
| The total loss is the arithmetic mean over all $2N$ samples: $L = \\frac{1}{2N} \\sum\_{i=0}^{2N-1} \\ell\_i$. | ||
| | ||
| | ||
| | ||
| \### Your Task | ||
| | ||
| Implement the function `nt\_xent\_loss(z, temperature)` that computes the NT-Xent loss using vectorized NumPy operations. | ||
| | ||
| | ||
| | ||
| \*\*Input Format\*\* | ||
| | ||
| \- `z`: A numpy array of shape `(2N, embedding\_dim)` containing \*\*L2-normalized\*\* embeddings. | ||
| | ||
| - \*\*Structure\*\*: The rows are interleaved such that `z\[2k]` and `z\[2k+1]` form a positive pair (two views of image $k$). | ||
| | ||
| - Visually: `\[View1\_Img1, View2\_Img1, View1\_Img2, View2\_Img2, ...]`. | ||
| | ||
| - All other interactions `z\[i]` and `z\[j]` (where `j` is not the pair of `i`) are considered negatives. | ||
| | ||
| \- `temperature`: A float scaling parameter ($\\tau > 0$). | ||
| | ||
| | ||
| | ||
| \### Output Format | ||
| | ||
| \- Returns `float`: The average NT-Xent loss over all $2N$ samples. | ||
| | ||
| | ||
| | ||
| \### Note on Stability | ||
| | ||
| \- You should implement the \*\*Log-Sum-Exp trick\*\* (subtracting the maximum value before exponentiation) to ensure numerical stability. | ||
| | ||
| | ||
| | ||
| \### Constraints | ||
| | ||
| \- $N \\geq 1$ (at least 1 image, so batch size $\\geq 2$) | ||
| | ||
| \- `embedding\_dim` $\\geq 1$ | ||
| | ||
| \- `temperature` $> 0$ | ||
| | ||
| \- Input embeddings are guaranteed to be L2-normalized | ||
| | ||
| \- \*\*Performance:\*\* Avoid explicit `for` loops. Use matrix operations and broadcasting. | ||
| | ||
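
The interleaved layout in this description fixes which row is each sample's positive. As a minimal sketch (both function names are hypothetical), the partner index can be computed without loops. The stacked variant is included only for contrast: it matches the `[View1_A, View1_B, ..., View2_A, View2_B...]` ordering mentioned in the learn section's implementation steps, which is not the ordering this task uses.

```python
import numpy as np


def positive_index_interleaved(batch_size: int) -> np.ndarray:
    """Partner of row i when rows 2k and 2k+1 are the two views of image k."""
    idx = np.arange(batch_size)
    return idx + 1 - 2 * (idx % 2)              # even rows look right, odd rows look left


def positive_index_stacked(batch_size: int) -> np.ndarray:
    """Partner of row i when the first N rows are view 1 and the last N are view 2."""
    idx = np.arange(batch_size)
    return (idx + batch_size // 2) % batch_size


print(positive_index_interleaved(6))  # [1 0 3 2 5 4]
print(positive_index_stacked(6))      # [3 4 5 0 1 2]
```

The interleaved expression is the same one the submitted solution uses for `pos_indices`; swapping in the stacked mapping on interleaved data would silently pair each row with a negative.
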
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| { | ||
| "input": "# N=2 (Total batch 4).\n# 0 and 1 are views of Cat. 2 and 3 are views of Dog.\n# 0 matches 1 (Positive). 0 mismatches 2 and 3 (Negatives).\n\nz = np.array([\n [1.0, 0.0], # 0: Cat View A\n [1.0, 0.0], # 1: Cat View B (Perfect match with 0)\n [0.0, 1.0], # 2: Dog View A (Orthogonal to 0)\n [0.0, 1.0] # 3: Dog View B (Orthogonal to 0)\n])\ntemperature = 0.5", | ||
| "output": "2.2395447662218846", | ||
| "reasoning": "Let's calculate loss for index 0 (Cat A):\n\n1. **Positive Pair**: Index 1. Sim = 1.0.\n\n2. **Negatives**: Index 2, Index 3. Sim = 0.0.\n\n3. **Terms**:\n\n - Numerator (Positive): $\\exp(1.0 / 0.5) = \\exp(2) \\approx 7.389$\n\n - Denominator (All $k \\neq 0$):\n\n - Term 1 (Pos): $\\exp(1.0/0.5) \\approx 7.389$\n\n - Term 2 (Neg): $\\exp(0.0/0.5) = 1$\n\n - Term 3 (Neg): $\\exp(0.0/0.5) = 1$\n\n - Denominator Sum: $7.389 + 1 + 1 = 9.389$\n\n4. **Loss for index 0**: $-\\log(7.389 / 9.389) \\approx 0.239$\n\n\n\nSince the setup is symmetric, all 4 indices have the same loss." | ||
| } | ||
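
The hand calculation in the `reasoning` field can be reproduced for every row with a literal, loop-based transcription of the loss definition. This is a sketch for cross-checking small inputs only: the name `nt_xent_loss_reference` is hypothetical, it deliberately ignores the task's no-loops constraint, and it skips the Log-Sum-Exp trick, so it is not a candidate solution.

```python
import math


def nt_xent_loss_reference(z, temperature):
    """Loop-based transcription of the NT-Xent definition (interleaved pairs)."""
    n = len(z)
    losses = []
    for i in range(n):
        j = i + 1 if i % 2 == 0 else i - 1      # positive partner of row i
        # Scaled dot products of row i against every row k
        sims = [sum(a * b for a, b in zip(z[i], z[k])) / temperature for k in range(n)]
        # Denominator sums over all k != i (the positive and every negative)
        denominator = sum(math.exp(sims[k]) for k in range(n) if k != i)
        losses.append(-math.log(math.exp(sims[j]) / denominator))
    return sum(losses) / n, losses


z = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
mean_loss, per_row = nt_xent_loss_reference(z, 0.5)
print(round(mean_loss, 4))   # 0.2395
```

All four per-row losses agree here, which is the symmetry the reasoning text relies on.
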
The example `output` should be 0.2395 (0.23954476622188456, as in the `temperature = 0.5` test case), not 2.2395.