Authors: Assran, Duval, Misra, Bojanowski, Vincent, Rabbat, LeCun, Ballas
Date: 2023-01
Category: Vision
Derives from: JEPA
Score: 6.65/10 — Draft

Image JEPA (I-JEPA)

1. Introduction

Self-supervised learning in computer vision has long been dominated by two paradigms: invariance-based methods (such as DINO, BYOL, and SimCLR), which learn by comparing augmented views of the same image, and generative methods (such as MAE), which learn by reconstructing masked pixels. Both paradigms carry fundamental limitations. Invariance-based methods require carefully hand-crafted augmentation pipelines — color jitter, Gaussian blur, solarization, multi-crop — and the resulting representations are biased toward whichever invariances the augmentations encode. Generative methods like MAE reconstruct raw pixels, forcing the encoder to waste capacity modeling low-level texture and noise that carry little semantic signal.

I-JEPA (Image-based Joint-Embedding Predictive Architecture), introduced by Assran, Duval, Misra, Bojanowski, Vincent, Rabbat, LeCun, and Ballas in January 2023, resolves both problems by operating entirely in learned representation space. Rather than predicting pixels, I-JEPA predicts the abstract representations of masked image regions, as produced by an exponential moving average (EMA) target encoder. This single shift — from pixel prediction to latent prediction — yields representations that are more semantic, more compute-efficient to learn, and less dependent on data augmentation.

I-JEPA is the first concrete instantiation of the Joint-Embedding Predictive Architecture framework for static images. The original JEPA position paper (LeCun, 2022) outlined a conceptual blueprint: learn representations by predicting in embedding space rather than input space, using an energy-based formulation with a learned predictor and a target encoder updated via EMA. I-JEPA translates this blueprint into a working system with three critical design decisions: (1) a multi-block masking strategy that forces prediction of spatially distributed, semantically rich regions; (2) a narrow transformer predictor that prevents the encoder from encoding trivially predictable information; and (3) minimal data augmentation — only random crop and horizontal flip — demonstrating that the architecture's inductive biases alone suffice for learning strong representations.

The results are striking. A ViT-H/14 encoder trained with I-JEPA for 300 epochs on ImageNet-1K achieves 80.9% top-1 linear probing accuracy, outperforming MAE's ViT-H by over 4 percentage points (76.6%) despite MAE requiring 1600 epochs and substantially more compute. In low-shot settings (1% labeled ImageNet), I-JEPA's advantage widens further, reaching approximately 72% versus MAE's 49% — a gap that underscores the superior semantic quality of I-JEPA's features.

In this article, we first describe the high-level method and intuition behind latent prediction (Section 2), then present the complete architecture with annotated diagrams (Section 3). We dissect each component — encoder, target encoder, predictor, masking strategy, and loss function — in Section 4. Sections 5 and 6 provide exhaustive implementation details and formal algorithms. We walk through the training procedure (Section 7) and inference protocol (Section 8) with detailed diagrams. Section 9 presents benchmark results and ablation studies. Section 10 situates I-JEPA within the broader JEPA family, and Section 11 summarizes key takeaways.

2. Method

The central idea of I-JEPA can be stated simply: mask large portions of an image, encode the visible remainder, and predict the representations of the masked regions without ever looking at their pixels.

Think of it this way: Imagine you are shown a photograph with several rectangular patches blacked out. A pixel-reconstruction approach (MAE) would ask you to paint in every missing pixel — the exact color of every blade of grass, every thread on a shirt. I-JEPA instead asks: "describe what is in each blacked-out region, at an abstract level." You might say "the top-left region contains sky," "the center region shows a dog's face." You don't need to reproduce exact pixels — you need to understand the scene's semantics. This is the fundamental difference between generative and joint-embedding predictive approaches.

More concretely, the method proceeds in five steps:

  1. Patchify: The input image (e.g., 224×224 pixels) is divided into a grid of non-overlapping patches (e.g., 14×14 pixel patches, yielding a 16×16 grid of 256 patches).
  2. Multi-block masking: Four rectangular target blocks are randomly sampled on the patch grid, each covering 15–20% of all patches. The remaining visible patches form the context. This aggressive masking removes roughly 70–85% of all patches.
  3. Context encoding: Only the visible (context) patches are fed through a Vision Transformer — the context encoder — producing a set of patch-level representations.
  4. Latent prediction: A smaller predictor transformer takes the context representations along with learnable mask tokens (positioned at target locations) and predicts what the representations of the missing patches should be.
  5. Loss computation: Simultaneously, the complete image (all patches) is fed through a target encoder — an EMA copy of the context encoder — producing ground-truth representations for the target patches. The L2 distance between predicted and actual target representations is minimized.

Gradients flow only through the context encoder and predictor; the target encoder is updated solely via exponential moving average, with no gradients. This asymmetry, combined with the narrow bottleneck of the predictor, prevents representational collapse — the degenerate solution where all representations become identical.

Common misconception: I-JEPA does NOT reconstruct pixels. It is not a masked autoencoder. The target space is the output of a learned encoder, not the raw image. This means the target representations can filter out unpredictable low-level details (high-frequency noise, texture variations) and retain only semantic content — precisely the information useful for downstream tasks.

3. Model Overview

Architecture Diagram

Figure 1: I-JEPA training architecture. The context encoder processes only visible patches. The target encoder (EMA, dashed) processes all patches. The predictor maps context representations to predicted target representations. L2 loss is computed between predictions and stop-gradient targets. Gradients (green) flow only through the predictor and context encoder.

At-a-Glance

| Property | Value |
|---|---|
| Input type | Image patches (non-overlapping, 14×14 or 16×16 pixels) |
| Masking strategy | Multi-block: 4 rectangular target blocks, each 15–20% of patches |
| Encoder architecture | Vision Transformer (ViT-L/16, ViT-H/14, ViT-H/16, ViT-G/16) |
| Predictor type | Narrow Vision Transformer (12 layers, 384-dim, 12 heads) |
| Loss function | L2 (MSE) in representation space: $\frac{1}{M}\sum_{i=1}^{M}\|\hat{s}_y^{(i)} - \text{sg}(s_y^{(i)})\|_2^2$ |
| Key result | 80.9% linear probing on ImageNet-1K (ViT-H/14, 300 epochs) |
| Parameters | ~632M (ViT-H encoder) + ~40M (predictor) |

4. Main Components of I-JEPA

4.1 Context Encoder

The context encoder is a standard Vision Transformer (ViT) that processes only the visible (unmasked) patches. Its job is to produce rich, semantic representations from a partial view of the image.

Architecture details: The primary configuration uses ViT-H/14 — a ViT-Huge model with 14×14 pixel patches. For a 224×224 image, this yields a 16×16 grid of $N = 256$ patches. Each patch is linearly embedded into a $D = 1280$-dimensional vector via a learned projection (a single linear layer applied to the flattened $14 \times 14 \times 3 = 588$-dimensional patch). Standard learnable positional embeddings are added to each patch token.
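As a concrete illustration, the patch-extraction step is just a reshape; a minimal NumPy sketch (not the official code, which typically fuses patchification with the linear embedding via a strided convolution):

```python
import numpy as np

def patchify(img: np.ndarray, p: int = 14) -> np.ndarray:
    """Split an (H, W, C) image into flattened non-overlapping p x p patches."""
    H, W, C = img.shape
    gh, gw = H // p, W // p                       # patch-grid dimensions
    x = img[:gh * p, :gw * p].reshape(gh, p, gw, p, C)
    x = x.transpose(0, 2, 1, 3, 4)                # (gh, gw, p, p, C)
    return x.reshape(gh * gw, p * p * C)          # (N, p*p*C)

patches = patchify(np.zeros((224, 224, 3)), p=14)
# patches.shape == (256, 588): a 16x16 grid, each patch flattened to 14*14*3
```

The 588-dimensional patch vectors would then be mapped to $D = 1280$ by the learned linear projection.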

The transformer backbone consists of 32 transformer blocks, each containing multi-head self-attention (16 heads, head dimension 80) and an MLP feedforward network (hidden dimension $4D = 5120$). Layer normalization is applied before each sub-layer (pre-norm architecture).

Crucially, the context encoder receives only the context patches — typically 40–75 out of 256 total patches (15–30% of the image). This is different from MAE, where the encoder also processes mask tokens. By excluding mask tokens from the encoder, I-JEPA achieves significant computational savings: the encoder's self-attention cost scales quadratically with sequence length, so processing 50 tokens instead of 256 reduces attention compute by roughly $25\times$.
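The quoted savings follow directly from the quadratic scaling of self-attention; a quick back-of-the-envelope check:

```python
def attention_cost_ratio(full_len: int, ctx_len: int) -> float:
    """Relative self-attention cost of a full vs. context-only sequence, O(L^2)."""
    return (full_len / ctx_len) ** 2

print(round(attention_cost_ratio(256, 50), 1))  # prints 26.2
```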

Why this design: Processing only visible patches forces the encoder to form representations that are complete enough to support downstream prediction of the missing regions. The encoder cannot rely on positional placeholders for masked regions; it must extract all necessary semantic information from the available context alone. This inductive bias produces features that are inherently more semantic and transferable.

4.2 Target Encoder

The target encoder is architecturally identical to the context encoder — the same ViT-H/14 with 32 layers and $D = 1280$. However, it differs in two critical ways: (1) it processes the complete image (all 256 patches), and (2) its parameters are updated exclusively via exponential moving average (EMA) of the context encoder's parameters.

EMA update rule:

$$\bar{\theta} \leftarrow m \cdot \bar{\theta} + (1 - m) \cdot \theta$$

where $\bar{\theta}$ denotes the target encoder parameters, $\theta$ denotes the context encoder parameters, and $m$ is the momentum coefficient.

Momentum schedule: The momentum $m$ follows a cosine schedule from $m_0 = 0.996$ to $m_T = 1.0$ over the course of training:

$$m(t) = 1 - (1 - m_0) \cdot \frac{1}{2}\left(1 + \cos\left(\frac{\pi t}{T}\right)\right)$$

where $t$ is the current training step and $T$ is the total number of training steps. Early in training, $m \approx 0.996$ allows the target encoder to track the context encoder relatively closely. As training progresses, $m \to 1.0$ and the target encoder becomes increasingly frozen, providing stable prediction targets.
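The schedule above is a one-liner; a minimal sketch:

```python
import math

def ema_momentum(t: int, T: int, m0: float = 0.996) -> float:
    """Cosine momentum schedule: m(0) = m0, rising monotonically to m(T) = 1.0."""
    return 1.0 - (1.0 - m0) * 0.5 * (1.0 + math.cos(math.pi * t / T))
```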

Stop-gradient: No gradients are backpropagated through the target encoder. The target representations are treated as fixed constants during each optimization step. This is essential — without stop-gradient, the model would trivially collapse by making both encoders output the same constant representation for all inputs.

Why EMA: The EMA target encoder serves two roles. First, it provides slowly evolving prediction targets that are more stable than the rapidly changing context encoder, reducing oscillation during training. Second, because the target encoder sees the full image (not just context patches), its representations encode the complete scene semantics — providing rich supervision for the predictor. The combination of EMA smoothing and stop-gradient prevents the collapse modes that plague naive Siamese architectures.

4.3 Predictor

The predictor is a narrow Vision Transformer — substantially smaller than the encoder — that maps context representations to predicted target representations. Its narrowness is a deliberate design choice, not merely a compute optimization.

Architecture: 12 transformer blocks, embedding dimension $D_p = 384$, 12 attention heads (head dimension 32), MLP hidden dimension $4 \times 384 = 1536$. This gives roughly 40M parameters — about 15× smaller than the ViT-H encoder.

Input processing: The predictor receives two types of tokens:

  1. Context tokens: The output representations from the context encoder (dimension $D = 1280$) are linearly projected down to the predictor's dimension ($D_p = 384$).
  2. Mask tokens: For each target patch position, a learnable mask token vector ($D_p = 384$) is created. Positional embeddings corresponding to the target patch positions are added to these mask tokens, so the predictor knows where in the image each target patch belongs.

Both sets of tokens (context + mask) are concatenated along the sequence dimension and processed jointly through the predictor's 12 transformer blocks. The self-attention mechanism allows mask tokens to attend to context tokens (and each other), enabling the predictor to "fill in" representations at target positions using information from visible regions.

After processing, only the output tokens at target positions are extracted and linearly projected back to the encoder dimension ($D = 1280$) for loss computation.

Why narrow: The predictor's bottleneck dimension ($384$ vs. the encoder's $1280$) is critical for representation quality. If the predictor were as wide as the encoder, it could simply memorize a mapping from context positions to target positions, allowing the encoder to encode trivially simple features. The narrow bottleneck forces the context encoder to provide rich, compressed representations that the predictor can usefully transform. Ablation studies confirm this: increasing predictor width degrades downstream performance, while reducing predictor depth (from 12 to 6 layers) costs approximately 0.5–1% linear probing accuracy. A linear predictor (0 layers) degrades performance by 3–4%, confirming the necessity of nonlinear predictive computation.

4.4 Masking Strategy

The multi-block masking strategy is one of I-JEPA's most important contributions. It determines which patches the model must predict and which it can see, directly shaping the representations learned.

Target block sampling: Four rectangular target blocks are independently sampled on the patch grid. For each block:

  • Scale: sampled uniformly from $[0.15, 0.2]$ — each block covers 15–20% of total patches
  • Aspect ratio: sampled uniformly from $[0.75, 1.5]$ — roughly square to mildly rectangular
  • Position: center position sampled uniformly; block boundaries clipped to the patch grid

Context formation: The context is the complement of the union of all four target blocks. Since each block removes 15–20% of patches and blocks may partially overlap, the context typically retains 15–30% of all patches — a very aggressive masking ratio.

Figure 2: Multi-block masking strategy. Left: original patch grid with four randomly sampled target blocks (T1–T4). Center: the context visible to the encoder (target regions removed). Right: the target regions whose representations must be predicted.

Why multi-block masking works: The paper provides extensive ablations comparing masking strategies. Random token masking (as in MAE) produces local, texture-level prediction tasks — each missing patch can be inferred from its immediate neighbors. Single large block masking creates a prediction task that is too spatially localized. Multi-block masking with four spatially distributed targets forces the encoder to capture global scene semantics: predicting a region in the top-left and another in the bottom-right of the image requires understanding the overall scene structure, not just local texture patterns.

The aggressive masking ratio (retaining only 15–30% of patches) further ensures that predictions cannot rely on simple interpolation from nearby visible patches. The encoder must compress the limited visible context into representations rich enough to support prediction of distant, disjoint regions.

4.5 Loss Function

I-JEPA minimizes the mean squared error (MSE) between predicted and actual target representations, computed at the individual patch level:

$$\mathcal{L} = \frac{1}{M} \sum_{i=1}^{M} \left\| \hat{s}_y^{(i)} - \text{sg}\left(s_y^{(i)}\right) \right\|_2^2$$

where:

  • $M$ = total number of target patches across all four target blocks
  • $\hat{s}_y^{(i)} = g_\phi\left(f_\theta(x_{\text{ctx}}), \mathbf{p}_i\right) \in \mathbb{R}^D$ — the predictor's output at target position $i$, computed from context encoder output $f_\theta(x_{\text{ctx}})$ and positional embedding $\mathbf{p}_i$
  • $s_y^{(i)} = f_{\bar{\theta}}(x)_i \in \mathbb{R}^D$ — the target encoder's representation at position $i$, extracted from the full-image encoding
  • $\text{sg}(\cdot)$ = stop-gradient operator: treats $s_y^{(i)}$ as a constant (no gradients flow through the target encoder)
  • $\|\cdot\|_2^2$ = squared L2 norm
  • $f_\theta$ = context encoder with parameters $\theta$ (trainable)
  • $f_{\bar{\theta}}$ = target encoder with EMA parameters $\bar{\theta}$ (not directly trained)
  • $g_\phi$ = predictor with parameters $\phi$ (trainable)
  • $x_{\text{ctx}}$ = context patches (visible subset of input image $x$)
  • $D = 1280$ for ViT-H models
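Under these definitions, the loss is a per-patch squared L2 distance averaged over target patches. A NumPy sketch (the stop-gradient is implicit here, since NumPy has no autograd; in PyTorch the targets would be detached):

```python
import numpy as np

def ijepa_loss(pred: np.ndarray, target: np.ndarray) -> float:
    """pred, target: (B, M, D). Mean over batch and patches of ||pred - target||^2."""
    sq_dist = np.sum((pred - target) ** 2, axis=-1)   # (B, M) per-patch squared L2
    return float(np.mean(sq_dist))

# Example: predictions off by 1.0 in every coordinate of a D=4 space give loss 4.0
loss = ijepa_loss(np.zeros((2, 3, 4)), np.ones((2, 3, 4)))
```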

Why L2 loss prevents collapse: A common concern with joint-embedding architectures is representational collapse — the trivial solution where both encoders map all inputs to the same constant vector, achieving zero loss. I-JEPA avoids collapse through the interaction of three mechanisms:

  1. Stop-gradient on the target encoder prevents the target from adapting to match trivial predictions.
  2. EMA updates make the target encoder a slowly evolving function, providing a stable optimization landscape. The target encoder's representations are diverse (different images produce different targets) because it is not optimized to minimize the loss directly.
  3. The narrow predictor bottleneck prevents the system from encoding trivially predictable representations. If the encoder produced constant outputs, the predictor (with limited capacity) could easily predict them — but the target encoder would still produce non-constant outputs (since it is an EMA of a randomly initialized encoder), and the loss would not reach zero.

Empirically, the authors confirm that removing any one of these three components (stop-gradient, EMA, narrow predictor) leads to collapse or significant performance degradation.

Note on loss formulation: The loss is computed per-patch, not per-block. Each of the $M$ target patches contributes equally to the loss, regardless of which target block it belongs to. There is no block-level pooling or averaging before loss computation. This patch-level granularity provides a dense training signal that encourages fine-grained spatial understanding.

5. Implementation Details

The following table presents the complete set of training hyperparameters, drawn from the paper and the public repository at https://github.com/facebookresearch/ijepa.

| Hyperparameter | Value | Notes |
|---|---|---|
| Encoder architecture | ViT-H/14 | 32 layers, 1280-dim, 16 heads, patch 14×14 |
| Predictor architecture | Narrow ViT | 12 layers, 384-dim, 12 heads |
| Image resolution | 224×224 | 448×448 for some evaluations |
| Patch grid | 16×16 = 256 patches | For ViT-H/14 at 224×224 |
| Optimizer | AdamW | $\beta_1 = 0.9, \beta_2 = 0.999$ |
| Base learning rate | 1.5 × 10⁻³ | Linearly scaled: $\text{lr} = \text{base\_lr} \times \text{batch\_size} / 256$ |
| Weight decay | 0.04 → 0.4 | Cosine schedule from 0.04 to 0.4 |
| LR schedule | Cosine decay | After linear warmup; min LR ~ 1 × 10⁻⁶ |
| Warmup epochs | 15 | Linear warmup from 0 to peak LR |
| Total epochs | 300 | On ImageNet-1K (1.28M images) |
| Batch size | 2048 | Distributed across GPUs via DDP |
| EMA momentum | 0.996 → 1.0 | Cosine schedule over training |
| Target blocks | 4 | Per image |
| Target block scale | [0.15, 0.2] | Fraction of total patches per block |
| Target block aspect ratio | [0.75, 1.5] | Roughly square to mildly rectangular |
| Data augmentation | RandomResizedCrop + HFlip | No color jitter, blur, solarize, or multi-crop |
| Mixed precision | Yes (FP16/BF16) | PyTorch AMP |
| Gradient clipping | Max norm ~3.0 | Per codebase defaults |
| GPU type | NVIDIA A100 (80GB) | |
| GPU count (ViT-H/14) | 16 A100s | ~72 GPU-hours total |
| Training dataset | ImageNet-1K | 1,281,167 training images, labels not used |

Repository Structure and Key Code References

The official repository at https://github.com/facebookresearch/ijepa is organized as follows:

# Key files and their roles:
# main_distributed.py      — Training entry point (launches distributed training)
# src/train.py             — Core training loop
# src/helper.py            — Model initialization, optimizer setup
# src/models/vision_transformer.py — VisionTransformer and VisionTransformerPredictor
# src/masks/multiblock.py  — MaskCollator: multi-block masking implementation
# src/utils/schedulers.py  — Cosine schedules for LR and EMA momentum
# configs/in1k_vith14_ep300.yaml — ViT-H/14 training configuration

The encoder is implemented as class VisionTransformer in src/models/vision_transformer.py. The predictor is implemented as class VisionTransformerPredictor in the same file. The multi-block masking logic lives in class MaskCollator in src/masks/multiblock.py, which generates mask indices in the dataloader's collate function.

6. Algorithm

Algorithm 1: I-JEPA Training (One Epoch)
Input: Training dataset $\mathcal{D}$ of images; context encoder $f_\theta$; target encoder $f_{\bar{\theta}}$; predictor $g_\phi$; EMA momentum schedule $m(t)$; learning rate schedule $\eta(t)$
Output: Updated parameters $\theta$, $\phi$, $\bar{\theta}$
1 for each mini-batch $\{x_1, \ldots, x_B\} \sim \mathcal{D}$ do
2 for each image $x_b$ in batch do
3 Apply RandomResizedCrop and horizontal flip to $x_b$
4 Patchify $x_b$ into tokens: $\{z_1, \ldots, z_N\}$ where $N = 256$
5 Sample 4 target blocks via Algorithm 2 → target indices $\mathcal{T}_b$, context indices $\mathcal{C}_b = \{1,\ldots,N\} \setminus \mathcal{T}_b$
6 end for
7 // Target encoder forward (no gradient)
8 with no_grad():
9 $S_y = f_{\bar{\theta}}(\{z_1, \ldots, z_N\})$   // Full image through target encoder → B×N×D
10 Extract target representations: $s_y^{(i)} = S_y[:, \mathcal{T}, :]$   // B×M×D
11 // Context encoder forward (with gradient)
12 $h_\text{ctx} = f_\theta(\{z_j : j \in \mathcal{C}\})$   // Context patches only → B×|C|×D
13 // Predictor forward (with gradient)
14 $\hat{h}_\text{ctx} = \text{Linear}_{D \to D_p}(h_\text{ctx})$   // Project to predictor dim → B×|C|×384
15 Create mask tokens $m_i \in \mathbb{R}^{D_p}$ for each $i \in \mathcal{T}$, add positional embeddings
16 $\hat{s}_y = g_\phi([\hat{h}_\text{ctx}; m_1, \ldots, m_M])$   // Predictor processes both → extract target positions
17 $\hat{s}_y = \text{Linear}_{D_p \to D}(\hat{s}_y)$   // Project back → B×M×D
18 // Compute loss
19 $\mathcal{L} = \frac{1}{M}\sum_{i=1}^{M} \|\hat{s}_y^{(i)} - \text{sg}(s_y^{(i)})\|_2^2$
20 // Update trainable parameters
21 $\theta \leftarrow \theta - \eta(t) \cdot \nabla_\theta \mathcal{L}$   // AdamW update for context encoder
22 $\phi \leftarrow \phi - \eta(t) \cdot \nabla_\phi \mathcal{L}$   // AdamW update for predictor
23 // EMA update for target encoder
24 $\bar{\theta} \leftarrow m(t) \cdot \bar{\theta} + (1 - m(t)) \cdot \theta$
25 $t \leftarrow t + 1$
26 end for
Algorithm 2: Multi-Block Mask Sampling
Input: Patch grid dimensions $H_g \times W_g$ (e.g., 16×16); number of target blocks $K=4$; scale range $[s_\text{min}, s_\text{max}] = [0.15, 0.2]$; aspect ratio range $[a_\text{min}, a_\text{max}] = [0.75, 1.5]$
Output: Target patch indices $\mathcal{T}$, context patch indices $\mathcal{C}$
1 $\mathcal{T} \leftarrow \emptyset$
2 $N \leftarrow H_g \times W_g$   // Total patches (e.g., 256)
3 for $k = 1, \ldots, K$ do
4 Sample scale $s_k \sim \text{Uniform}(s_\text{min}, s_\text{max})$
5 Sample aspect ratio $a_k \sim \text{Uniform}(a_\text{min}, a_\text{max})$
6 Compute block dimensions: $n_k = \lfloor s_k \cdot N \rfloor$   // Number of patches in block
7 $h_k = \lfloor \sqrt{n_k \cdot a_k} \rfloor$, $w_k = \lfloor \sqrt{n_k / a_k} \rfloor$   // Block height and width in patches
8 Clip: $h_k = \min(h_k, H_g)$, $w_k = \min(w_k, W_g)$
9 Sample center position: $r_k \sim \text{Uniform}(0, H_g)$, $c_k \sim \text{Uniform}(0, W_g)$
10 Compute block bounds: $r_\text{start} = \max(0, r_k - h_k/2)$, $c_\text{start} = \max(0, c_k - w_k/2)$
11 Clip to grid: $r_\text{end} = \min(r_\text{start} + h_k, H_g)$, $c_\text{end} = \min(c_\text{start} + w_k, W_g)$
12 $\mathcal{T}_k = \{(r, c) : r \in [r_\text{start}, r_\text{end}), c \in [c_\text{start}, c_\text{end})\}$
13 $\mathcal{T} \leftarrow \mathcal{T} \cup \mathcal{T}_k$
14 end for
15 $\mathcal{C} \leftarrow \{1, \ldots, N\} \setminus \mathcal{T}$
16 return $\mathcal{T}$, $\mathcal{C}$
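Algorithm 2 can be sketched in self-contained Python. This is a simplification: blocks are placed by sampling an on-grid top-left corner rather than sampling a center and clipping, and the official MaskCollator handles additional details (such as sampling a context block and coordinating masks across a batch):

```python
import math
import random

def sample_masks(Hg=16, Wg=16, K=4, scale=(0.15, 0.20), aspect=(0.75, 1.5), seed=None):
    """Sample K rectangular target blocks on an Hg x Wg patch grid.

    Returns (target_indices, context_indices) as sorted lists of flat patch ids.
    """
    rng = random.Random(seed)
    N = Hg * Wg
    targets = set()
    for _ in range(K):
        s = rng.uniform(*scale)                 # fraction of all patches
        a = rng.uniform(*aspect)                # height/width ratio
        n = int(s * N)                          # patches in this block
        h = min(max(int(round(math.sqrt(n * a))), 1), Hg)
        w = min(max(int(round(math.sqrt(n / a))), 1), Wg)
        r0 = rng.randrange(Hg - h + 1)          # top-left corner, kept on-grid
        c0 = rng.randrange(Wg - w + 1)
        targets |= {r * Wg + c
                    for r in range(r0, r0 + h) for c in range(c0, c0 + w)}
    context = sorted(set(range(N)) - targets)
    return sorted(targets), context
```

Because blocks may overlap, the union of targets typically covers 45–70% of the grid, leaving the remainder as context.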
Algorithm 3: I-JEPA Inference (Feature Extraction)
Input: Test image $x$; trained target encoder $f_{\bar{\theta}}$
Output: Feature vector $v \in \mathbb{R}^D$
1 Apply standard resize and center crop to $x$ (224×224 or 448×448)
2 Patchify $x$ into tokens: $\{z_1, \ldots, z_N\}$ with positional embeddings
3 $S = f_{\bar{\theta}}(\{z_1, \ldots, z_N\})$   // Full image through target encoder → N×D
4 $v = \frac{1}{N}\sum_{i=1}^{N} S_i$   // Average-pool all patch tokens → D-dim vector
5 return $v$

7. Training

This section describes exactly what happens during a single I-JEPA training iteration, from raw image to parameter update.

Step-by-Step Training Iteration

  1. Image sampling and augmentation: A mini-batch of $B = 2048$ images is sampled from ImageNet-1K. Each image undergoes only two augmentations: RandomResizedCrop (scale 0.3–1.0, aspect ratio 3/4–4/3, resized to 224×224) and random horizontal flip. No color jitter, Gaussian blur, solarization, or multi-crop is applied.
  2. Patchification: Each 224×224 image is divided into a 16×16 grid of non-overlapping 14×14 pixel patches, yielding $N = 256$ tokens. Each patch is linearly embedded to dimension $D = 1280$ and summed with its learnable positional embedding.
  3. Mask generation: The MaskCollator generates four rectangular target blocks per image (Algorithm 2). The union of these blocks defines the target set $\mathcal{T}$ (typically 180–215 patches); the complement defines the context set $\mathcal{C}$ (typically 40–75 patches).
  4. Target encoder forward pass (no gradient): All $N = 256$ patch tokens (with positional embeddings) are passed through the target encoder $f_{\bar{\theta}}$. This produces target representations $S_y \in \mathbb{R}^{B \times N \times D}$. The representations at target positions are extracted: $s_y \in \mathbb{R}^{B \times M \times D}$, where $M = |\mathcal{T}|$. This entire computation is wrapped in torch.no_grad().
  5. Context encoder forward pass (with gradient): Only the context patches $\{z_j : j \in \mathcal{C}\}$ are passed through the context encoder $f_\theta$. This produces context representations $h_\text{ctx} \in \mathbb{R}^{B \times |\mathcal{C}| \times D}$. Because the encoder processes only 40–75 tokens (not 256), this is approximately $3{-}6\times$ cheaper than a full forward pass.
  6. Predictor forward pass (with gradient): The context representations are projected from $D = 1280$ to $D_p = 384$ via a linear layer. Learnable mask tokens ($D_p = 384$) are created at each target position and combined with positional embeddings encoding their spatial location on the patch grid. The concatenation of projected context tokens and positioned mask tokens is processed through the predictor's 12 transformer blocks. Output tokens at target positions are extracted and projected back to dimension $D = 1280$, yielding predicted representations $\hat{s}_y \in \mathbb{R}^{B \times M \times D}$.
  7. Loss computation: The MSE loss $\mathcal{L} = \frac{1}{M}\sum_{i=1}^{M}\|\hat{s}_y^{(i)} - \text{sg}(s_y^{(i)})\|_2^2$ is computed between predicted and target representations.
  8. Backpropagation: Gradients of $\mathcal{L}$ are computed with respect to $\theta$ (context encoder) and $\phi$ (predictor). AdamW updates both parameter sets. The learning rate follows a cosine schedule with 15-epoch linear warmup; weight decay follows a cosine schedule from 0.04 to 0.4.
  9. EMA update: The target encoder parameters are updated: $\bar{\theta} \leftarrow m(t) \cdot \bar{\theta} + (1 - m(t)) \cdot \theta$, where the momentum $m(t)$ follows a cosine schedule from 0.996 to 1.0.
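The steps above can be condensed into a toy PyTorch training step. This is a sketch only: the encoders and predictor are stand-in MLPs rather than ViTs, and the toy predictor ignores positioned mask tokens; the point is to show the gradient flow, stop-gradient, and EMA mechanics.

```python
import torch
import torch.nn as nn

D = 8  # toy embedding dim (the real model uses D = 1280)

# Stand-ins for the real ViT encoders and narrow predictor.
context_enc = nn.Sequential(nn.Linear(D, D), nn.GELU(), nn.Linear(D, D))
target_enc  = nn.Sequential(nn.Linear(D, D), nn.GELU(), nn.Linear(D, D))
predictor   = nn.Sequential(nn.Linear(D, D), nn.GELU(), nn.Linear(D, D))
target_enc.load_state_dict(context_enc.state_dict())   # start from the same weights
for p in target_enc.parameters():
    p.requires_grad_(False)                            # never trained directly

opt = torch.optim.AdamW(
    list(context_enc.parameters()) + list(predictor.parameters()), lr=1e-3)

def train_step(patches, ctx_idx, tgt_idx, m=0.996):
    # patches: (B, N, D) toy per-patch embeddings
    with torch.no_grad():                              # target path: stop-gradient
        s_y = target_enc(patches)[:, tgt_idx]          # (B, M, D)
    h = context_enc(patches[:, ctx_idx])               # context path: gradients flow
    # Toy "prediction": pool the context and map it to every target position.
    pred = predictor(h.mean(dim=1, keepdim=True)).expand_as(s_y)
    loss = ((pred - s_y) ** 2).sum(dim=-1).mean()      # per-patch squared L2, averaged
    opt.zero_grad(); loss.backward(); opt.step()
    with torch.no_grad():                              # EMA update of target encoder
        for tp, cp in zip(target_enc.parameters(), context_enc.parameters()):
            tp.mul_(m).add_(cp, alpha=1 - m)
    return loss.item()

loss = train_step(torch.randn(2, 16, D), torch.arange(4), torch.arange(4, 16))
```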

Training Architecture Diagram

Figure 3: Detailed training pipeline showing the complete data flow from raw image through patchification, masking, dual-encoder processing, prediction, and loss computation. Green solid lines indicate the trainable path with gradient flow. Dashed lines indicate the EMA/stop-gradient path. Dimension annotations show tensor shapes at each stage.

Training Objective (Mathematical Formulation)

Formally, the I-JEPA training objective minimizes:

$$\min_{\theta, \phi} \; \mathbb{E}_{x \sim \mathcal{D}} \; \mathbb{E}_{\mathcal{T}, \mathcal{C}} \left[ \frac{1}{|\mathcal{T}|} \sum_{i \in \mathcal{T}} \left\| g_\phi\left(f_\theta(x_\mathcal{C}), \mathbf{p}_i\right) - \text{sg}\left(f_{\bar{\theta}}(x)_i\right) \right\|_2^2 \right]$$

where $x_\mathcal{C}$ denotes the context patches of image $x$, $\mathbf{p}_i$ is the positional encoding for target position $i$, $f_{\bar{\theta}}(x)_i$ is the target encoder's representation at position $i$, and the inner expectation is over the random masking.

8. Inference

At inference time, I-JEPA discards both the predictor and the masking mechanism. The trained target encoder (the EMA encoder $f_{\bar{\theta}}$) serves as a general-purpose feature extractor.

Feature Extraction Pipeline

  1. Preprocessing: The input image is resized (shorter side to 256 pixels) and center-cropped to 224×224 (or 448×448 for high-resolution evaluation).
  2. Patchification: The image is divided into non-overlapping patches (14×14 or 16×16 pixels), producing $N$ tokens.
  3. Encoding: All $N$ patch tokens (with positional embeddings) are passed through the frozen target encoder, yielding patch-level representations $\{s_1, \ldots, s_N\} \in \mathbb{R}^{N \times D}$.
  4. Pooling: Representations are average-pooled across all patches: $v = \frac{1}{N}\sum_{i=1}^{N} s_i \in \mathbb{R}^D$. Notably, I-JEPA does not use a CLS token; global average pooling over patch tokens is the default feature extraction method.
  5. Downstream head: The pooled feature vector $v$ is fed to a task-specific head (e.g., a linear classifier for probing, or a fine-tuned network).

Evaluation Protocols

Linear probing: A single linear layer is trained on top of frozen features. Optimizer: SGD with learning rate sweep (0.01–0.3), cosine schedule, 100 epochs, batch size 16384.

Attentive probing: A single cross-attention layer pools information from all patch tokens (rather than simple averaging). This provides the model a learned way to weight patches differently per-class, improving accuracy by approximately 0.5–1% over average pooling. The attention pooling layer has a small number of trainable parameters while the encoder remains frozen.
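A single-query cross-attention pool can be sketched in NumPy as follows (illustrative names and shapes; the actual probe uses multi-head attention with learned projections):

```python
import numpy as np

def attentive_pool(tokens, query, Wk, Wv):
    """tokens: (N, D); query: (Dq,); Wk: (D, Dq); Wv: (D, Dv). Returns (Dv,)."""
    keys = tokens @ Wk                        # (N, Dq)
    vals = tokens @ Wv                        # (N, Dv)
    scores = keys @ query / np.sqrt(len(query))
    w = np.exp(scores - scores.max())         # softmax over the N patch tokens
    w /= w.sum()
    return w @ vals                           # attention-weighted pooled feature

rng = np.random.default_rng(0)
pooled = attentive_pool(rng.normal(size=(256, 16)), rng.normal(size=4),
                        rng.normal(size=(16, 4)), rng.normal(size=(16, 8)))
# pooled.shape == (8,)
```

With a uniform query (or identical tokens), this reduces to plain average pooling; the learned query lets the probe up-weight the patches most relevant to each task.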

Fine-tuning: The encoder can also be fully fine-tuned end-to-end with a classification head for maximum downstream performance, though the paper primarily reports linear and attentive probing results to assess representation quality.

Multi-resolution evaluation: For ViT-H/14 at 448×448, the image produces a 32×32 grid of patches ($N = 1024$). The positional embeddings (learned for a 16×16 grid during training) are bilinearly interpolated to the 32×32 grid. This resolution increase at test time consistently improves accuracy (e.g., 80.9% → 81.6% for ViT-H/14).
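The resize amounts to bilinear interpolation over the 2-D embedding grid. A minimal numpy sketch, assuming align-corners sampling:

```python
import numpy as np

def interpolate_pos_embed(pos, new_g):
    """Bilinear resize of learned positional embeddings (sketch).
    pos: (g*g, D) embeddings trained on a g x g patch grid; returns
    (new_g*new_g, D) embeddings for a larger test-time grid,
    e.g. 16x16 -> 32x32 for ViT-H/14 at 448x448."""
    g = int(round(np.sqrt(pos.shape[0])))
    grid = pos.reshape(g, g, -1)
    xs = np.linspace(0.0, g - 1, new_g)    # align-corners sample positions
    x0 = np.floor(xs).astype(int)
    x1 = np.minimum(x0 + 1, g - 1)
    f = xs - x0                            # fractional offsets
    # Interpolate along rows, then along columns
    rows = grid[x0] * (1 - f)[:, None, None] + grid[x1] * f[:, None, None]
    out = rows[:, x0] * (1 - f)[None, :, None] + rows[:, x1] * f[None, :, None]
    return out.reshape(new_g * new_g, -1)

pos = np.random.default_rng(0).standard_normal((16 * 16, 8))
pos_448 = interpolate_pos_embed(pos, 32)   # (1024, 8): grid for 448x448 input
```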

[Figure 4 diagram: test image → resize + center-crop (224×224×3) → patchify + positional embeddings (256×1280) → frozen target encoder (ViT-H/14, all 256 patches, no masking) → average pool over patches (1×1280) → trainable linear head (1280→K) → class prediction. The EMA target encoder is used for its smoothed, stable representations; the predictor and masking mechanism are discarded.]
Figure 4: Inference pipeline. The frozen target encoder processes all image patches. Representations are average-pooled and fed to a lightweight downstream head. The predictor and masking mechanism are not used at inference time.

9. Results & Benchmarks

ImageNet-1K Linear Probing

The primary evaluation metric is top-1 accuracy of a linear classifier trained on frozen features from ImageNet-1K (without labels during pre-training).

| Method | Architecture | Epochs | Resolution | Linear Probe (Top-1) |
|---|---|---|---|---|
| I-JEPA | ViT-H/14 | 300 | 224 | 80.9% |
| I-JEPA | ViT-H/14 | 300 | 448 | 81.6% |
| I-JEPA | ViT-H/16 | 300 | 224 | 79.3% |
| I-JEPA | ViT-G/16 | 300 | 224 | 81.6% |
| MAE | ViT-H/14 | 1600 | 224 | 76.6% |
| data2vec | ViT-L/16 | 800 | 224 | 76.6% |
| DINO | ViT-B/16 | 300 | 224 | 78.2% |
| iBOT | ViT-L/16 | 250 | 224 | 79.5% |
| MSN | ViT-L/16 | 600 | 224 | 79.4% |

I-JEPA's ViT-H/14 at 80.9% outperforms MAE's ViT-H/14 by 4.3 percentage points despite training for only 300 epochs versus MAE's 1600, roughly a 5× reduction in training compute. It is also competitive with iBOT (79.5%) and MSN (79.4%) despite using no multi-crop augmentation and no color jitter.

Low-Shot Evaluation (1% ImageNet)

| Method | Architecture | 1% ImageNet (Top-1) |
|---|---|---|
| I-JEPA | ViT-H/14 | ~72% |
| DINO | ViT-B/16 | ~68% |
| MAE | ViT-H/14 | ~49% |

The gap between I-JEPA and MAE in the 1% regime (~23 percentage points) is far larger than the gap in full-data linear probing (~4.3 points). This confirms that I-JEPA's features are more semantically structured: with very few labeled examples, semantic features produce much better classifiers than features that encode low-level texture.

Transfer Learning

| Dataset | I-JEPA (ViT-H/14) | MAE (ViT-H/14) |
|---|---|---|
| Places205 | ~60.3% | ~57.9% |
| iNaturalist 2018 | ~71.2% | ~60.5% |
| CIFAR-100 | ~88.2% | ~83.8% |
| Food-101 | ~90.5% | ~85.7% |

I-JEPA shows consistent advantages across all transfer benchmarks. The improvement is particularly pronounced on fine-grained recognition tasks like iNaturalist (+10.7 points), where semantic understanding of species-level differences matters far more than texture reconstruction ability.

Ablation Studies

The paper presents thorough ablations isolating the contribution of each design choice:

Masking Strategy

| Masking Strategy | Δ Linear Probe (Top-1) vs. Multi-block | Notes |
|---|---|---|
| Multi-block (4 targets, scale 0.15–0.2) | Baseline | |
| Single large block | −2 to −3% | Spatially localized, less semantic |
| Random token masking (MAE-style) | −3 to −4% | Local texture prediction, not global |
| Large target blocks (scale 0.4+) | −1 to −2% | Prediction too easy from context |
| Fewer target blocks (1–2) | Worse than 4 blocks | Less diverse spatial prediction |

Predictor Architecture

| Predictor Depth | Δ Linear Probe (Top-1) vs. 12 Layers | Notes |
|---|---|---|
| 12 layers | Baseline (optimal) | |
| 6 layers | −0.5 to −1% | Slightly limited capacity |
| 0 layers (linear) | −3 to −4% | No nonlinear predictive power |

Masking Ratio

| Effective Visible Ratio | Relative Performance |
|---|---|
| 15–30% visible (I-JEPA default) | Optimal |
| ~50% visible | −1 to −2% |
| ~75% visible | −3% |

This is the opposite of MAE's behavior, where 75% masking is optimal. I-JEPA benefits from more aggressive masking because it operates in representation space: with fewer visible patches, the encoder must produce more compressed, semantic features. In pixel space, excessive masking makes reconstruction intractably difficult; in representation space, the target encoder filters away unpredictable details, making aggressive masking feasible.

EMA and Collapse

Removing the EMA mechanism (using a direct copy of the context encoder as the target, or allowing gradients through the target encoder) leads to representational collapse. The EMA + stop-gradient combination is non-negotiable for training stability.
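The update itself is one line per parameter tensor; a sketch, leaving out the momentum schedule (roughly 0.996 annealed toward 1.0 over training):

```python
import numpy as np

def ema_update(target_params, context_params, m=0.996):
    """Target-encoder update: theta_bar <- m * theta_bar + (1 - m) * theta.
    Called once per optimizer step, outside autograd; together with the
    stop-gradient on targets, this slow-moving copy is what prevents
    representational collapse."""
    return [m * t + (1 - m) * c for t, c in zip(target_params, context_params)]

target = [np.zeros(3)]
context = [np.ones(3)]
target = ema_update(target, context, m=0.9)   # moves 10% toward context
```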

Computational Efficiency

[Figure 5 chart: approximate A100 GPU-hours for ImageNet pre-training. I-JEPA ViT-H/14: ~72 GPU-hrs (300 epochs, 80.9% linear probe); MAE ViT-H/14: ~360 GPU-hrs (1600 epochs, 76.6%); DINO/iBOT ViT-L/16: ~300+ GPU-hrs with multi-crop (78–79%).]
Figure 5: Approximate GPU-hour comparison for ImageNet pre-training. I-JEPA achieves the highest linear probing accuracy with roughly 5× less compute than MAE and substantially less than multi-crop methods like DINO/iBOT.

Why I-JEPA is cheaper:

  1. The context encoder processes only ~50 visible tokens rather than 256, making self-attention roughly 5× cheaper.
  2. No multi-crop: DINO processes 10 crops per image.
  3. No pixel decoder: MAE's decoder must reconstruct 256 × patch_size² pixels.

10. Connection to the JEPA Family

I-JEPA occupies a foundational position in the JEPA lineage as the first concrete, working instantiation of the Joint-Embedding Predictive Architecture for vision.

What I-JEPA Borrows

The conceptual framework of I-JEPA originates from the JEPA position paper (LeCun, 2022), which proposed learning representations by predicting in embedding space rather than input space. I-JEPA adopts JEPA's three core architectural commitments: (1) a learned predictor (not fixed similarity), (2) an EMA target encoder for producing stable prediction targets, and (3) asymmetric processing (context sees partial input, target sees full input). The EMA target update mechanism specifically echoes earlier work in BYOL and MoCo, which demonstrated that momentum-updated target networks prevent collapse without negative samples.

What is Genuinely Novel

Key contribution of I-JEPA: The multi-block masking strategy for images, which creates a spatially distributed, semantically demanding prediction task that eliminates the need for hand-crafted data augmentations. By sampling multiple small, separated target blocks (rather than a single contiguous region or random tokens), I-JEPA forces the encoder to develop global scene understanding. Combined with the demonstration that this approach achieves state-of-the-art linear probing with only random crop and horizontal flip (no color jitter, blur, solarization, or multi-crop), I-JEPA establishes that the prediction task itself — not augmentation engineering — is sufficient for learning semantic representations.
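The multi-block sampler described above can be sketched as follows. Parameter names are illustrative, the aspect-ratio range is an assumption, and the context derivation is simplified (the paper samples a separate large context block and removes target regions from it):

```python
import numpy as np

def sample_multiblock_masks(grid=16, n_targets=4, scale=(0.15, 0.2),
                            aspect=(0.75, 1.5), rng=None):
    """Sample n_targets rectangular target blocks on a grid x grid patch
    grid (16x16 for a 224x224 image with 14-pixel patches). Each block
    covers scale[0]-scale[1] of the image at a random aspect ratio;
    blocks may overlap one another."""
    if rng is None:
        rng = np.random.default_rng()
    masks = []
    for _ in range(n_targets):
        area = rng.uniform(*scale) * grid * grid
        ar = rng.uniform(*aspect)
        h = min(grid, max(1, int(round(np.sqrt(area * ar)))))
        w = min(grid, max(1, int(round(np.sqrt(area / ar)))))
        top = rng.integers(0, grid - h + 1)
        left = rng.integers(0, grid - w + 1)
        m = np.zeros((grid, grid), dtype=bool)
        m[top:top + h, left:left + w] = True
        masks.append(m)
    return masks

masks = sample_multiblock_masks(rng=np.random.default_rng(0))
# Simplified context: patches outside every target block
context = ~np.any(masks, axis=0)
```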

The second major contribution is the narrow predictor bottleneck design — using a 384-dimensional transformer predictor with a 1280-dimensional encoder. While prior work (BYOL, SimSiam) used MLP predictors, I-JEPA's use of a transformer predictor with deliberate capacity restriction provides a principled mechanism for encouraging the encoder to form information-rich representations: the predictor's limited bandwidth forces the encoder to do the heavy semantic lifting.

Influence on Later Variants

I-JEPA directly inspired and enabled several subsequent JEPA variants:

  • V-JEPA (Video JEPA): Extended I-JEPA's multi-block masking to spatiotemporal tubes in video, predicting representations of masked space-time regions. V-JEPA demonstrated that the same latent prediction principle scales from static images to temporal dynamics.
  • MC-JEPA (Motion-Content JEPA): Applied joint-embedding prediction to learn motion (optical flow) and content features with a shared encoder, showing that the JEPA framework extends beyond static appearance.
  • A-JEPA (Audio JEPA): Adapted I-JEPA's architecture to audio spectrograms, replacing spatial patches with time-frequency patches and applying multi-block masking in the spectrogram domain.

I-JEPA also established key design principles that became standard in the JEPA family: (1) the target encoder always sees the complete input; (2) the context encoder processes only unmasked regions for efficiency; (3) the predictor is deliberately capacity-constrained; (4) minimal augmentation is preferred, letting the prediction task drive representation learning.

Relationship to Non-JEPA Methods

I-JEPA can be understood as occupying a unique position between two established paradigms:

  • vs. Masked Autoencoders (MAE, BEiT): Both use masking, but MAE reconstructs pixels/tokens while I-JEPA predicts in learned representation space. This shift from input space to latent space is the defining characteristic of the JEPA family.
  • vs. Contrastive/Invariance methods (DINO, BYOL, SimCLR): Both use joint embeddings, but invariance methods compare different augmented views of the same image. I-JEPA instead compares a predicted representation against a target representation, introducing a predictive (rather than purely contrastive) learning objective. This removes dependence on hand-crafted augmentation invariances.

11. Summary

Key takeaway: I-JEPA demonstrates that predicting abstract representations of masked image regions — rather than reconstructing pixels or comparing augmented views — produces more semantic, more efficient, and more transferable visual representations. A ViT-H/14 trained with I-JEPA for 300 epochs on ImageNet achieves 80.9% linear probing accuracy, outperforming MAE (76.6% at 1600 epochs) with roughly 5× less compute and no reliance on hand-crafted augmentations.

Main contribution: The multi-block masking strategy, which creates a spatially distributed prediction task over four disjoint image regions, combined with a narrow transformer predictor that forces semantic encoding, establishes a new paradigm for self-supervised visual representation learning — one driven by the structure of the prediction task rather than by data augmentation engineering.

When to use I-JEPA vs. alternatives:

  • Use I-JEPA when you need high-quality visual representations with limited compute budget; when your downstream task requires semantic rather than textural features; when you want to avoid designing augmentation pipelines; or when you need strong low-shot performance.
  • Use MAE when you need pixel-level reconstruction for tasks like inpainting or segmentation, or when you plan to fine-tune rather than linearly probe.
  • Use DINO/iBOT when you need features robust to specific augmentation-based invariances (color, blur), or when multi-crop protocols are already part of your pipeline.

12. References

  1. Assran, M., Duval, Q., Misra, I., Bojanowski, P., Vincent, P., Rabbat, M., LeCun, Y., & Ballas, N. (2023). Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture. CVPR 2023. arXiv:2301.08243.
  2. LeCun, Y. (2022). A Path Towards Autonomous Machine Intelligence. Technical Report, Meta AI. The foundational JEPA position paper outlining energy-based latent prediction.
  3. He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). Masked Autoencoders Are Scalable Vision Learners. CVPR 2022. arXiv:2111.06377.
  4. Caron, M., Touvron, H., Misra, I., Jégou, H., Mairal, J., Bojanowski, P., & Joulin, A. (2021). Emerging Properties in Self-Supervised Vision Transformers (DINO). ICCV 2021. arXiv:2104.14294.
  5. Zhou, J., Wei, C., Wang, H., Shen, W., Xie, C., Yuille, A., & Kong, T. (2022). iBOT: Image BERT Pre-Training with Online Tokenizer. ICLR 2022. arXiv:2111.07832.
  6. Baevski, A., Hsu, W.-N., Xu, Q., Babu, A., Gu, J., & Auli, M. (2022). data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language. ICML 2022. arXiv:2202.03555.
  7. Assran, M., Caron, M., Misra, I., Bojanowski, P., Bordes, F., Vincent, P., Joulin, A., Rabbat, M., & Ballas, N. (2022). Masked Siamese Networks for Label-Efficient Learning. ECCV 2022. arXiv:2204.07141.
  8. Grill, J.-B., Strub, F., Altché, F., Tallec, C., Richemond, P. H., Buchatskaya, E., Doersch, C., Pires, B. Á., Guo, Z., Azar, M. G., Piot, B., Kavukcuoglu, K., Munos, R., & Valko, M. (2020). Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning (BYOL). NeurIPS 2020. arXiv:2006.07733.
  9. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., & Houlsby, N. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (ViT). ICLR 2021. arXiv:2010.11929.
  10. Bardes, A., Ponce, J., & LeCun, Y. (2024). Revisiting Feature Prediction for Learning Visual Representations from Video (V-JEPA). arXiv:2404.08471.