Causal-JEPA: Learning World Models through Object-Level Latent Interventions
Nam, Le Lidec, Maes, LeCun, Balestriero — February 2026
1. Introduction
The Joint Embedding Predictive Architecture (JEPA) family has established a powerful paradigm for self-supervised representation learning: encode observations into latent space, mask portions of the input, and train a predictor to reconstruct the missing latent targets without ever decoding back to pixel space. I-JEPA introduced spatially informative multi-block masking for images; V-JEPA extended this to spatiotemporal tube masking for video; LeJEPA contributed spectral regularization via SIGReg to provably prevent representational collapse. Each of these methods, however, treats masking as a geometric operation—rectangular patches, random tubes, or frequency bands are removed without regard for the semantic structure of the scene.
Causal-JEPA (Nam et al., 2026) identifies a fundamental limitation of geometry-based masking: it cannot distinguish between correlation and causation. A patch-masked predictor may learn that a ball and a shadow co-occur without learning that the ball causes the shadow. In scenes with multiple interacting objects, geometric masking conflates the contributions of distinct entities, yielding representations that capture statistical regularities but fail to support counterfactual or interventional reasoning—the hallmark of genuine physical understanding.
The key insight of Causal-JEPA is to reinterpret the JEPA masking operation through the lens of Pearl's interventional calculus. Rather than masking arbitrary spatial regions, Causal-JEPA masks entire objects from the scene, effectively performing a latent analogue of the $\text{do}(\cdot)$ operator. When an object is removed, the predictor must infer the causal consequences of that object's absence on the remaining scene representation—learning, for instance, that removing a ball eliminates the shadow it casts, or that removing a support causes a stacked block to fall. This object-level masking transforms the standard JEPA prediction task into an interventional prediction task, grounding causal reasoning directly in the self-supervised learning objective.
Contributions
- Object-level masking — A masking strategy that operates on semantically coherent object regions rather than geometric patches, connecting JEPA training to causal intervention.
- Latent interventional prediction — A predictor architecture that reasons about the causal consequences of object removal in latent space, enabling counterfactual scene understanding.
- Causal consistency loss — An auxiliary objective enforcing that the predicted latent state under intervention is consistent with the actual latent state of the intervened scene.
- State-of-the-art physical reasoning — Strong results on physics-oriented benchmarks including PHYRE, CRAFT, and CoPhy, outperforming both patch-based JEPA variants and pixel-reconstruction baselines on tasks requiring causal understanding.
Differentiation from LeJEPA and V-JEPA
LeJEPA contributes spectral regularization (SIGReg) to prevent collapse; Causal-JEPA inherits this regularization while fundamentally changing the masking semantics. V-JEPA extends JEPA to video with spatiotemporal tube masking; Causal-JEPA also operates on temporal scene sequences but replaces tube masking with object-level removal, requiring object-aware processing that V-JEPA does not possess. Where V-JEPA asks "what happens in this spatiotemporal region?", Causal-JEPA asks "what happens because of this object?"
2. Method
Imagine a child's physics experiment: you have a scene with blocks, ramps, and balls. Cover your eyes, and someone removes one object. Now open your eyes and try to predict what changed. If a ramp was removed, the ball that was rolling down it is now on the ground. If a wall was removed, the ball that was bouncing off it has rolled away. To play this game well, you must understand not just what objects look like, but how they causally influence each other. Causal-JEPA trains a neural network to play exactly this game—but in latent representation space rather than pixel space.
The method proceeds in three conceptual stages:
Stage 1: Object-aware scene encoding. Given an input scene (a single image or a short video clip), the encoder processes the full scene into a set of latent tokens. Crucially, the architecture maintains object-level structure: each token corresponds to a semantically meaningful object or scene region, not merely a spatial patch. Object segmentation can be obtained from an off-the-shelf method (e.g., SAM-based masks) or learned jointly. The target encoder (an EMA copy) processes the same full scene to produce target representations.
Stage 2: Object-level intervention. One or more objects are selected for "intervention"—their tokens are removed from the context provided to the predictor. This is the analogue of Pearl's $\text{do}()$ operation: rather than merely observing what the scene looks like without seeing a particular region (as in patch masking), we are simulating what the world would be like if that object did not exist. The predictor receives the remaining (non-intervened) object tokens plus a set of learnable mask tokens indicating which objects were removed.
Stage 3: Interventional prediction. The predictor must reconstruct the target-encoder representations of the remaining objects as they would appear in the absence of the intervened objects. This is the critical distinction from standard JEPA: the predictor is not merely filling in missing patches, but reasoning about causal consequences. If a light source is removed, the predictor should predict that shadow regions in other objects' representations will change accordingly.
In causal inference, observing that $X$ and $Y$ co-occur ($P(Y|X)$) is fundamentally different from intervening to set $X$ and observing $Y$ ($P(Y|\text{do}(X))$). Standard JEPA patch masking is analogous to conditional observation—predicting what's behind a curtain. Causal-JEPA's object masking is analogous to intervention—predicting what happens when an entity is removed from the causal graph. This difference is what enables genuine causal understanding rather than mere pattern completion.
The training signal comes from comparing the predictor's output with target-encoder representations of scenes where the object is actually absent (either by rendering/composing such scenes or by using the factored structure of the latent space). Combined with LeJEPA-style spectral regularization to prevent collapse and EMA-based target stabilization inherited from the JEPA family, this yields representations that encode not just appearance but causal structure.
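The three stages can be illustrated with a toy NumPy sketch (not the authors' code; shapes and names are ours). Each patch token is assigned to an object, and an intervention drops every patch of the chosen object before the context is encoded:

```python
import numpy as np

# Illustrative sketch of Stages 1-2: partition patch tokens by object and
# remove an "intervened" object. `obj_id[p]` assigns patch p to an object;
# -1 marks background. All names are our own, not from the paper.

rng = np.random.default_rng(0)
N, D = 16, 8                       # 16 patch tokens, toy embedding dim 8
tokens = rng.normal(size=(N, D))   # stand-in for encoder patch tokens
obj_id = np.array([0]*4 + [1]*6 + [2]*3 + [-1]*3)  # 3 objects + background

def intervene(tokens, obj_id, removed):
    """Latent analogue of do(remove object `removed`): return the context
    tokens plus the kept/dropped patch positions."""
    keep = obj_id != removed              # drop every patch of the object
    return tokens[keep], np.where(keep)[0], np.where(~keep)[0]

ctx, ctx_pos, masked_pos = intervene(tokens, obj_id, removed=1)
assert ctx.shape == (10, D)               # the 6 patches of object 1 are gone
```

The predictor (Stage 3) then receives `ctx` plus mask tokens at `masked_pos` and must reconstruct the target representations there.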
3. Model Overview
At-a-Glance
| Component | Specification |
|---|---|
| Input | Object-scene pairs: images or short video clips with object segmentation masks |
| Masking | Object-level masking — entire objects removed from scene context |
| Context Encoder | Vision Transformer (ViT) processing context tokens (non-masked objects) |
| Target Encoder | EMA copy of context encoder; processes full scene to produce targets |
| Predictor | Lightweight Transformer with object-conditioned cross-attention; predicts target representations of affected objects |
| Loss | Smooth-$\ell_1$ latent prediction loss + SIGReg spectral regularization + causal consistency term |
| Key Result | State-of-the-art on PHYRE (+4.2% AUCCESS over V-JEPA), CRAFT (+6.1% accuracy), CoPhy (best counterfactual prediction) |
| Params | ~307M (ViT-L/16 encoder) + ~24M predictor (reported configuration) |
Training Architecture Diagram
4. Main Components of Causal-JEPA
4.1 Context Encoder
WHAT: The context encoder $f_\theta$ is a Vision Transformer (ViT-L/16 in the primary configuration) that processes the context portion of the scene—i.e., all tokens corresponding to objects that were not removed by the object-level mask. Unlike standard ViT processing where all patch tokens are processed uniformly, Causal-JEPA's encoder receives a variable-length sequence depending on how many objects (and their constituent patches) were masked.
HOW: The input image is patchified into a grid of $N = (H/p) \times (W/p)$ tokens of dimension $D = 1024$ (for ViT-L), where $p = 16$ is the patch size. Object masks from the segmentation module determine which patches belong to which objects. Patches belonging to masked (intervened) objects are removed from the token sequence before encoding. Positional embeddings are added before masking so that the encoder retains spatial awareness of where the context tokens originate. The encoder consists of 24 Transformer blocks with 16 attention heads.
WHY: Processing only context tokens (rather than the full sequence with mask tokens inserted) serves two purposes: (1) it forces the encoder to build representations from partial information, encouraging robust features that do not depend on any single object, and (2) it reduces computational cost roughly proportionally to the fraction of masked tokens. Ablations reported in the paper show that processing context-only tokens yields +1.8% on PHYRE compared to processing the full sequence with zeroed-out masked positions, suggesting that the encoder benefits from not being distracted by uninformative placeholder tokens.
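A minimal sketch of the context-token preparation described above (assumed implementation details, not released code): positional embeddings are added to all patch tokens first, then the patches of intervened objects are dropped, so the surviving tokens retain their original spatial identity.

```python
import numpy as np

# Context-token preparation sketch for 4.1. The positional table would be
# learned in practice; here both tokens and positions are random stand-ins.

rng = np.random.default_rng(1)
N, D = 196, 32                          # 14x14 grid for a 224px image, toy D
patch_tokens = rng.normal(size=(N, D))
pos_embed = rng.normal(size=(N, D))     # learned positional embeddings

masked_patches = np.zeros(N, dtype=bool)
masked_patches[50:80] = True            # patches covered by the removed object

x = patch_tokens + pos_embed            # position attached BEFORE masking
context = x[~masked_patches]            # variable-length context sequence
assert context.shape == (N - 30, D)
```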
4.2 Target Encoder (EMA)
WHAT: The target encoder $f_{\bar{\theta}}$ is a momentum-updated copy of the context encoder that processes the full, unmasked scene to produce target representations. It is never updated by gradient descent; instead, its parameters are an exponential moving average (EMA) of the context encoder's parameters.
HOW: After each training step, the target encoder parameters are updated as:
$$\bar{\theta} \leftarrow \tau \bar{\theta} + (1 - \tau) \theta$$

where $\tau$ follows a cosine schedule from $\tau_0 = 0.996$ to $\tau_1 = 1.0$ over the course of training. The target encoder receives all $N$ patch tokens (no masking) and produces a complete set of representations $\mathbf{s} = f_{\bar{\theta}}(\mathbf{x}) \in \mathbb{R}^{N \times D}$. The target representations for the masked objects are extracted by selecting tokens at positions corresponding to the masked object patches.
WHY: The EMA target encoder provides a slowly-evolving target signal that stabilizes training and prevents representational collapse. Because it sees the full scene, its representations capture complete object interactions—providing the "ground truth" against which the predictor's interventional predictions are measured. The cosine EMA schedule, inherited from BYOL/I-JEPA practice, ensures that targets are initially responsive to encoder updates (lower $\tau$) and become increasingly stable as training progresses (higher $\tau$). Ablations show that a fixed $\tau = 0.999$ underperforms the cosine schedule by 1.2% on CRAFT, consistent with findings across the JEPA family.
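The EMA update and cosine $\tau$ schedule can be written out directly (a toy sketch; the ramp shape is our assumption of the standard BYOL/I-JEPA-style schedule, with the paper's endpoints $\tau_0 = 0.996$, $\tau_1 = 1.0$):

```python
import math

# EMA target update with a cosine tau ramp from tau0 to tau1.

def tau_at(step, total_steps, tau0=0.996, tau1=1.0):
    """Cosine ramp of the EMA coefficient: tau0 at step 0, tau1 at the end."""
    cos = 0.5 * (1.0 + math.cos(math.pi * step / total_steps))
    return tau1 - (tau1 - tau0) * cos

def ema_update(theta_bar, theta, tau):
    """theta_bar <- tau * theta_bar + (1 - tau) * theta, per parameter."""
    return [tau * tb + (1 - tau) * t for tb, t in zip(theta_bar, theta)]

assert abs(tau_at(0, 100) - 0.996) < 1e-9    # responsive targets early
assert abs(tau_at(100, 100) - 1.0) < 1e-9    # frozen targets at the end
```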
4.3 Predictor
WHAT: The predictor $g_\phi$ is a lightweight Transformer that takes context encoder outputs (representations of non-masked objects) and learnable mask tokens as input, and predicts what the target encoder would produce for the masked objects' tokens. This is the component that must learn causal reasoning: given the remaining objects, predict the latent state of removed objects (or, more precisely, predict the target representations at the masked positions).
HOW: The predictor is a 6-layer Transformer with 8 attention heads and hidden dimension $D_{\text{pred}} = 384$. It operates as follows:
- Context tokens from the encoder ($\mathbb{R}^{N_c \times D}$) are projected to predictor dimension via a linear layer: $\mathbf{h}_c = W_{\text{proj}} \cdot f_\theta(\mathbf{x}_{\text{context}}) \in \mathbb{R}^{N_c \times D_{\text{pred}}}$.
- Learnable mask tokens $\mathbf{m} \in \mathbb{R}^{N_m \times D_{\text{pred}}}$ are initialized for each masked position, augmented with positional embeddings indicating where the masked objects were located.
- The concatenated sequence $[\mathbf{h}_c; \mathbf{m}]$ is processed by the predictor Transformer. Cross-attention allows mask tokens to attend to context tokens, while self-attention among mask tokens enables reasoning about relationships between multiple masked objects.
- The output mask tokens are projected back to encoder dimension via a linear head: $\hat{\mathbf{s}} = W_{\text{out}} \cdot g_\phi([\mathbf{h}_c; \mathbf{m}])_{\text{mask}} \in \mathbb{R}^{N_m \times D}$.
WHY: The predictor's narrow bottleneck ($D_{\text{pred}} = 384$ vs. $D = 1024$) is critical for preventing a degenerate solution where the encoder collapses representations to make prediction trivial. The bottleneck forces the predictor to extract and utilize high-level causal structure rather than memorizing low-level details. Ablations show that increasing predictor width to $D_{\text{pred}} = 768$ degrades PHYRE performance by 2.7%, while reducing to $D_{\text{pred}} = 192$ degrades by 1.4%, indicating that $384$ strikes the right balance between capacity and information bottleneck. The 6-layer depth was chosen over the 12-layer alternative as the lighter predictor avoids representational bypassing where the predictor becomes powerful enough to trivially solve the task without the encoder learning meaningful features.
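The four steps above can be traced at the shape level with a stub (the Transformer blocks are replaced by a placeholder nonlinearity; every name here is illustrative, since no code is released):

```python
import numpy as np

# Shape-level sketch of the predictor forward pass in 4.3.

rng = np.random.default_rng(2)
D, D_pred = 1024, 384                    # encoder dim / predictor bottleneck
N_c, N_m = 150, 46                       # context / masked token counts

W_proj = rng.normal(size=(D, D_pred)) * 0.01   # linear in-projection
W_out = rng.normal(size=(D_pred, D)) * 0.01    # linear output head

h_ctx = rng.normal(size=(N_c, D)) @ W_proj                  # context to 384-d
mask_tok = np.tile(rng.normal(size=(1, D_pred)), (N_m, 1))  # learnable token
mask_tok += rng.normal(size=(N_m, D_pred)) * 0.01           # + positions

seq = np.concatenate([h_ctx, mask_tok], axis=0)  # joint attention sequence
seq = seq + 0.1 * np.tanh(seq)                   # stand-in for the 6 blocks
s_hat = seq[N_c:] @ W_out                        # mask outputs, back to D
assert s_hat.shape == (N_m, D)
```

Note how the $1024 \to 384 \to 1024$ shape makes the information bottleneck explicit.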
4.4 Masking Strategy
WHAT: Causal-JEPA's defining innovation is object-level masking. Instead of masking rectangular patches (I-JEPA), spatiotemporal tubes (V-JEPA), or random token subsets, entire semantically coherent objects are removed from the context. This transforms the prediction task from geometric inpainting to causal intervention.
HOW: Object masks are obtained via one of two methods:
- Pre-computed masks: For datasets with available segmentation (e.g., synthetic physics benchmarks), ground-truth object masks are used directly.
- Learned/external masks: For natural images, an off-the-shelf segmentation model (the paper uses SAM-based masks) provides object proposals. The top-$K$ objects by area are retained as maskable entities.
At each training step, $M \sim \text{Uniform}(1, M_{\max})$ objects are selected for masking, where $M_{\max} = 3$ by default. Selection is biased toward objects that occupy between 5% and 40% of the image area, avoiding trivially small objects and near-total occlusions. All patches overlapping the selected objects' segmentation masks are removed from the context encoder's input.
WHY: The choice of object-level masking is motivated by three observations. First, causal structure in physical scenes is organized around objects, not patches—forces, contacts, shadows, and occlusions are properties of inter-object relationships. Second, object removal is a natural analogue of the $\text{do}()$ operator from causal inference: it severs incoming causal arrows to the removed object while preserving the rest of the causal graph. Third, object-level masking provides a more semantically consistent training signal than patch masking: the predictor always faces a well-defined counterfactual question ("what if this object weren't here?") rather than an ill-defined geometric one ("what's behind this rectangle?"). Ablations demonstrate that object-level masking outperforms random patch masking by 6.8% on PHYRE and 8.3% on CRAFT, and outperforms superpixel-based masking (which approximates objects but does not guarantee semantic coherence) by 2.4% on PHYRE.
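The sampling policy above (draw $M \sim \text{Uniform}\{1, \ldots, M_{\max}\}$ objects from those covering 5–40% of the image) can be sketched as follows; the function name and clipping of $M$ to the eligible count are our assumptions:

```python
import numpy as np

# Object-selection sketch for 4.4: area-filtered sampling of masked objects.

def select_objects(areas, rng, m_max=3, lo=0.05, hi=0.40):
    """areas: per-object fraction of image covered. Returns masked indices."""
    eligible = np.where((areas >= lo) & (areas <= hi))[0]
    if eligible.size == 0:
        return np.array([], dtype=int)        # nothing maskable in this scene
    m = int(rng.integers(1, m_max + 1))       # M ~ Uniform{1..M_max}
    m = min(m, eligible.size)
    return rng.choice(eligible, size=m, replace=False)

rng = np.random.default_rng(3)
areas = np.array([0.02, 0.10, 0.25, 0.55, 0.30])  # objects 1, 2, 4 eligible
picked = select_objects(areas, rng)
assert set(picked) <= {1, 2, 4} and 1 <= len(picked) <= 3
```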
4.5 Loss Function
The total loss is a weighted combination of three terms:
$$\mathcal{L} = \mathcal{L}_{\text{pred}} + \lambda \cdot \mathcal{L}_{\text{causal}} + \mu \cdot \mathcal{L}_{\text{SIGReg}}$$

Prediction Loss ($\mathcal{L}_{\text{pred}}$): The primary objective measures the discrepancy between the predictor's output and the target encoder's representations at masked positions. Causal-JEPA uses the smooth $\ell_1$ (Huber) loss:
$$\mathcal{L}_{\text{pred}} = \frac{1}{|M|} \sum_{i \in M} \text{smooth}_{\ell_1}\!\left(\hat{\mathbf{s}}_i - \text{sg}(\mathbf{s}_i^{\text{tgt}})\right)$$

where $M$ is the set of masked token indices, $\hat{\mathbf{s}}_i = g_\phi([\mathbf{h}_c; \mathbf{m}])_i \in \mathbb{R}^D$ is the predictor's output for masked position $i$, $\mathbf{s}_i^{\text{tgt}} = f_{\bar{\theta}}(\mathbf{x})_i \in \mathbb{R}^D$ is the target encoder's representation at position $i$, $\text{sg}(\cdot)$ denotes stop-gradient, and:
$$\text{smooth}_{\ell_1}(\mathbf{z}) = \begin{cases} \frac{1}{2}\|\mathbf{z}\|^2 / \beta & \text{if } \|\mathbf{z}\| < \beta \\ \|\mathbf{z}\| - \frac{\beta}{2} & \text{otherwise} \end{cases}$$with $\beta = 1.0$. The smooth $\ell_1$ loss is preferred over MSE because it is less sensitive to outlier targets, which arise when the EMA target encoder produces temporarily inconsistent representations during early training.
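The piecewise definition above translates directly into code (a NumPy sketch operating on one residual vector, matching the $\beta = 1.0$ setting):

```python
import numpy as np

# smooth-l1 / Huber loss on a residual vector z, as defined in 4.5:
# quadratic for ||z|| < beta, linear beyond, continuous at the joint.

def smooth_l1(z, beta=1.0):
    norm = np.linalg.norm(z)
    if norm < beta:
        return 0.5 * norm**2 / beta
    return norm - 0.5 * beta

assert abs(smooth_l1(np.array([0.6, 0.8])) - 0.5) < 1e-9  # ||z|| = 1.0 (joint)
assert abs(smooth_l1(np.array([1.2, 1.6])) - 1.5) < 1e-9  # ||z|| = 2.0, linear
```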
Causal Consistency Loss ($\mathcal{L}_{\text{causal}}$): This is Causal-JEPA's novel auxiliary objective. It enforces that the predictor's output for non-masked objects is consistent with how those objects' representations change when an object is removed. Formally, let $\mathbf{s}_j^{\text{full}} = f_{\bar{\theta}}(\mathbf{x}_{\text{full}})_j$ be the target encoder's representation of object $j$ in the full scene, and let $\mathbf{s}_j^{\text{int}} = f_{\bar{\theta}}(\mathbf{x}_{\setminus i})_j$ be its representation in the scene with object $i$ removed. The causal consistency loss is:
$$\mathcal{L}_{\text{causal}} = \frac{1}{|C|} \sum_{j \in C} \left\| g_\phi([\mathbf{h}_c; \mathbf{m}])_j^{\text{residual}} - \text{sg}\!\left(\mathbf{s}_j^{\text{full}} - \mathbf{s}_j^{\text{int}}\right) \right\|_2^2$$where $C$ is the set of context (non-masked) object indices, and $g_\phi(\cdot)_j^{\text{residual}}$ is an auxiliary output head on the predictor that estimates the change in representation for context objects due to the intervention. This loss ensures that the predictor does not merely predict target representations independently but learns the causal effect of object removal on the remaining scene.
All variables:
- $M$: set of token indices corresponding to masked (intervened) objects
- $C$: set of token indices corresponding to context (non-masked) objects
- $\hat{\mathbf{s}}_i \in \mathbb{R}^D$: predictor output at masked position $i$
- $\mathbf{s}_i^{\text{tgt}} \in \mathbb{R}^D$: target encoder output at position $i$ (full scene)
- $\mathbf{s}_j^{\text{full}} \in \mathbb{R}^D$: target representation of object $j$ in full scene
- $\mathbf{s}_j^{\text{int}} \in \mathbb{R}^D$: target representation of object $j$ in intervened scene (object $i$ removed)
- $g_\phi(\cdot)_j^{\text{residual}} \in \mathbb{R}^D$: predictor's estimate of the causal effect on object $j$
- $\lambda = 0.5$: weight for causal consistency loss
- $\mu = 0.01$: weight for SIGReg regularization
- $\beta = 1.0$: smooth $\ell_1$ transition threshold
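With those variables fixed, the causal consistency term is a mean squared error between the residual head's output and the detached representation change; a minimal sketch (ours, not the authors'):

```python
import numpy as np

# Causal consistency loss sketch: regress the auxiliary residual predictions
# onto the stop-gradiented change (full scene minus intervened scene) at each
# context position j in C.

def causal_consistency_loss(r_hat, s_full, s_int):
    """r_hat, s_full, s_int: (|C|, D) arrays; targets are treated as sg()."""
    target = s_full - s_int                   # actual effect of the removal
    diff = r_hat - target
    return np.mean(np.sum(diff**2, axis=1))   # mean squared l2 over C

rng = np.random.default_rng(4)
s_full = rng.normal(size=(5, 8))
s_int = rng.normal(size=(5, 8))
# A perfect residual prediction incurs zero loss:
assert causal_consistency_loss(s_full - s_int, s_full, s_int) == 0.0
```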
SIGReg Spectral Regularization ($\mathcal{L}_{\text{SIGReg}}$): Inherited from LeJEPA, this term prevents representational collapse by regularizing the singular value distribution of the encoder output. Given a batch of encoder outputs $\mathbf{S} \in \mathbb{R}^{B \times D}$ (averaged over spatial positions per sample), the singular values $\sigma_1, \ldots, \sigma_D$ of the centered matrix $\mathbf{S} - \bar{\mathbf{S}}$ are computed, and the regularizer encourages a uniform distribution:
$$\mathcal{L}_{\text{SIGReg}} = \text{KL}\!\left(p \,\middle\|\, \mathcal{U}\right) = \sum_{k=1}^{D} p_k \log\!\left(D\, p_k\right), \qquad p_k = \frac{\sigma_k}{\sum_{j=1}^{D} \sigma_j}$$

where the normalized singular values $p_k$ are treated as a probability distribution and compared to the uniform distribution $\mathcal{U} = (1/D, \ldots, 1/D)$. This ensures that the encoder uses all $D$ dimensions of its representation space, preventing collapse to a lower-dimensional manifold.
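A sketch of the spectral regularizer as written above (our implementation of the stated formula; the $\varepsilon$ stabilizer is an assumption). A rank-collapsed batch is penalized much more heavily than a near-isotropic one:

```python
import numpy as np

# SIGReg sketch: KL between the normalized singular-value spectrum of the
# centered batch features and the uniform distribution 1/D.

def sigreg(S, eps=1e-8):
    """S: (B, D) pooled encoder outputs for one batch (B >= D)."""
    S = S - S.mean(axis=0, keepdims=True)        # center over the batch
    sig = np.linalg.svd(S, compute_uv=False)     # D singular values
    p = sig / (sig.sum() + eps)                  # spectrum as a distribution
    D = S.shape[1]
    return float(np.sum(p * np.log(p * D + eps)))  # KL(p || uniform)

rng = np.random.default_rng(5)
iso = rng.normal(size=(512, 16))                               # uses all dims
collapsed = np.outer(rng.normal(size=512), rng.normal(size=16))  # rank-1
assert sigreg(iso) < sigreg(collapsed)            # collapse is penalized
```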
4.6 Object Segmentation Module
WHAT: A variant-specific component that provides the object masks needed for object-level masking. This module bridges the gap between raw pixel inputs and the object-level abstraction that Causal-JEPA requires.
HOW: The paper explores two configurations. For synthetic benchmarks (PHYRE, CRAFT, CoPhy), ground-truth object segmentations are available and used directly—each scene is composed of distinct rigid bodies whose boundaries are known. For natural image experiments, a pre-trained SAM (Segment Anything Model) generates object proposals, which are filtered by area (retaining objects covering 5–40% of the image) and confidence score. The segmentation module is not trained jointly with the JEPA components; it is frozen and serves as a preprocessing step.
WHY: Decoupling segmentation from the JEPA training loop simplifies the optimization landscape and allows the causal reasoning components to focus on learning physics rather than simultaneously solving segmentation. The paper notes that jointly learning segmentation degrades PHYRE performance by 3.1%, likely because noisy segmentation masks during early training produce inconsistent causal supervision. However, the authors identify joint object discovery as an important direction for future work, connecting to the slot-attention and object-centric representation literature.
5. Implementation Details
| Hyperparameter | Value |
|---|---|
| Encoder architecture | ViT-L/16 (24 layers, 16 heads, $D=1024$) |
| Predictor architecture | 6 layers, 8 heads, $D_{\text{pred}}=384$ |
| Patch size | $16 \times 16$ |
| Input resolution | $224 \times 224$ (images), $224 \times 224 \times T$ (video, $T=16$ frames) |
| Optimizer | AdamW ($\beta_1=0.9$, $\beta_2=0.95$, weight decay $= 0.05$) |
| Base learning rate | $1.5 \times 10^{-4}$ (scaled linearly with batch size) |
| LR schedule | Cosine decay with 40-epoch linear warmup |
| Batch size | 1024 (across GPUs) |
| Training epochs | 300 (synthetic), 600 (natural images) |
| GPUs | 32 × A100 (80 GB) |
| EMA schedule ($\tau$) | Cosine from 0.996 → 1.0 |
| Max masked objects ($M_{\max}$) | 3 |
| Object area filter | 5%–40% of image area |
| Causal consistency weight ($\lambda$) | 0.5 |
| SIGReg weight ($\mu$) | 0.01 |
| Smooth $\ell_1$ $\beta$ | 1.0 |
| Total parameters | ~331M (307M encoder + 24M predictor) |
| Training time | ~72 hours (300 epochs on synthetic data) |
Note: No public code repository is available for Causal-JEPA. Hyperparameters are reported from the paper; some values (noted where relevant) are inherited from the V-JEPA and LeJEPA codebases that Causal-JEPA builds upon.
6. Algorithm
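Since no official code is released (see Section 5), the following is a schematic, runnable sketch of one training iteration, with the networks replaced by stubs; it mirrors the step-by-step procedure in Section 7 but is not the authors' implementation.

```python
import numpy as np

rng = np.random.default_rng(6)
D = 8

def encoder(tokens, W):                 # stand-in for f_theta / f_theta_bar
    return np.tanh(tokens @ W)

def train_step(x, obj_id, removed, W, W_bar, tau=0.996):
    keep = obj_id != removed                        # object-level do() mask
    s_tgt = encoder(x, W_bar)                       # target encoding (no grad)
    h_ctx = encoder(x[keep], W)                     # context encoding
    # Stub "predictor": broadcast a context summary to each masked position.
    s_hat = np.repeat(h_ctx.mean(axis=0, keepdims=True), (~keep).sum(), axis=0)
    loss = np.mean((s_hat - s_tgt[~keep]) ** 2)     # latent prediction loss
    # Backprop through g_phi and f_theta would go here (omitted in this stub);
    # the causal consistency and SIGReg terms are likewise omitted.
    W_bar = tau * W_bar + (1 - tau) * W             # EMA target update
    return loss, W_bar

x = rng.normal(size=(12, D))
obj_id = np.array([0]*4 + [1]*4 + [2]*4)
W = rng.normal(size=(D, D)) * 0.1
loss, W_bar = train_step(x, obj_id, removed=2, W=W, W_bar=W.copy())
assert np.isfinite(loss) and loss >= 0.0
```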
7. Training
Step-by-Step: One Training Iteration
- Sample batch. Draw $B = 1024$ scenes from $\mathcal{D}$, each with pre-computed object segmentation masks.
- Object selection. For each scene, sample $M \in \{1, 2, 3\}$ eligible objects (area between 5% and 40%) to mask. Compute patch sets $P_{\text{mask}}$ and $P_{\text{ctx}}$.
- Target encoding (no grad). Pass the full, unmasked scene through the EMA target encoder $f_{\bar{\theta}}$ (no gradients flow) to obtain target representations $\mathbf{s}^{\text{tgt}} \in \mathbb{R}^{N \times D}$. Also compute the target encoder's output on the intervened scene (with masked object pixels zeroed out or removed) to obtain $\mathbf{s}^{\text{int}}$, needed for the causal consistency loss.
- Context encoding (grad). Pass only the context patches $P_{\text{ctx}}$ through the trainable context encoder $f_\theta$ with their original positional embeddings. Output: $\mathbf{h}_c \in \mathbb{R}^{|P_{\text{ctx}}| \times D}$.
- Prediction (grad). Project context representations to predictor dimension. Concatenate with positional-embedded mask tokens $\mathbf{m} \in \mathbb{R}^{|P_{\text{mask}}| \times D_{\text{pred}}}$. Process through the 6-layer predictor Transformer. Extract predicted targets $\hat{\mathbf{s}} \in \mathbb{R}^{|P_{\text{mask}}| \times D}$ (via output projection) and residual predictions $\hat{\mathbf{r}} \in \mathbb{R}^{|P_{\text{ctx}}| \times D}$ (via auxiliary head).
- Loss computation. Compute $\mathcal{L}_{\text{pred}}$ (smooth $\ell_1$ between $\hat{\mathbf{s}}$ and $\text{sg}(\mathbf{s}^{\text{tgt}}[P_{\text{mask}}])$), $\mathcal{L}_{\text{causal}}$ ($\ell_2$ between $\hat{\mathbf{r}}$ and the stop-gradiented difference $\mathbf{s}^{\text{full}} - \mathbf{s}^{\text{int}}$ at context positions), and $\mathcal{L}_{\text{SIGReg}}$ (spectral regularization on $\mathbf{h}_c$). Combine: $\mathcal{L} = \mathcal{L}_{\text{pred}} + 0.5 \cdot \mathcal{L}_{\text{causal}} + 0.01 \cdot \mathcal{L}_{\text{SIGReg}}$.
- Gradient update. Backpropagate $\mathcal{L}$ through the predictor $g_\phi$ and context encoder $f_\theta$. Update parameters via AdamW.
- EMA update. Update target encoder: $\bar{\theta} \leftarrow \tau(t) \cdot \bar{\theta} + (1 - \tau(t)) \cdot \theta$ where $\tau(t)$ follows a cosine schedule.
Training Architecture with Gradient Flow
8. Inference
At inference time, Causal-JEPA supports two modes: standard feature extraction (for downstream classification, detection, or regression) and interventional/counterfactual prediction (for physical reasoning tasks that require answering "what if" questions).
Standard Feature Extraction
The trained context encoder $f_\theta$ processes the full, unmasked input scene. All $N$ patch tokens are encoded, and the resulting representations are aggregated (typically via global average pooling) into a single vector $\bar{\mathbf{s}} \in \mathbb{R}^D$. This vector is passed to a downstream head — either a frozen linear probe (to evaluate representation quality) or a fine-tuned MLP (for task-specific adaptation). In this mode the predictor $g_\phi$ and target encoder $f_{\bar{\theta}}$ are not needed and can be discarded; only $f_\theta$ is used. (The interventional mode below retains the predictor.)
Interventional Prediction
For physical reasoning benchmarks that require counterfactual answers (e.g., "will the ball reach the target if the ramp is removed?"), Causal-JEPA uses the full training pipeline at inference: the encoder processes the context (scene minus the intervened object), the predictor generates the expected latent state under intervention, and a task head makes predictions from the predicted interventional representations. This mode uniquely leverages the causal training objective and is not available to standard JEPA variants.
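The interventional mode can be sketched end-to-end with stand-ins for the three components (encoder, predictor, task head); the pipeline shape follows the description above, while every function body here is a placeholder of our own:

```python
import numpy as np

# Interventional inference sketch: encode the scene minus the queried object,
# predict the latent state "under do()", score it with a task head.

rng = np.random.default_rng(7)
D = 8

def encode(tokens):                  # frozen f_theta stand-in
    return np.tanh(tokens)

def predict_intervention(ctx):       # g_phi stand-in
    return ctx.mean(axis=0)          # predicted latent under the intervention

def task_head(z, w):                 # e.g. "will the ball reach the target?"
    return 1.0 / (1.0 + np.exp(-z @ w))

scene = rng.normal(size=(12, D))
obj_id = np.array([0]*4 + [1]*4 + [2]*4)
ctx = encode(scene[obj_id != 1])     # counterfactual: object 1 removed
p = task_head(predict_intervention(ctx), rng.normal(size=D))
assert 0.0 < p < 1.0                 # probability for the "what if" query
```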
Downstream Evaluation Protocols
- Linear probing: Freeze $f_\theta$, train a single linear layer on the pooled representation. Used for PHYRE and CRAFT evaluation.
- Fine-tuning: Unfreeze $f_\theta$ with a reduced learning rate ($\frac{1}{10}$ of pretraining LR), train end-to-end with the task head. Used for CoPhy counterfactual prediction.
- Interventional probing: A novel protocol introduced by the paper where the predictor is also used at test time to answer counterfactual queries. The encoder processes the context, the predictor generates interventional predictions, and the task head operates on these predictions.
9. Results & Benchmarks
Physical Reasoning Benchmarks
| Method | PHYRE-B (AUCCESS ↑) | CRAFT (Acc. ↑) | CoPhy-Balls (MSE ↓) | CoPhy-Blocks (MSE ↓) |
|---|---|---|---|---|
| Supervised ResNet-50 | 72.1 | 61.4 | 0.142 | 0.168 |
| MAE (ViT-L) | 74.3 | 64.2 | 0.128 | 0.151 |
| BYOL (ViT-L) | 75.8 | 66.7 | 0.121 | 0.143 |
| I-JEPA (ViT-L) | 78.4 | 70.3 | 0.108 | 0.129 |
| V-JEPA (ViT-L) | 80.2 | 72.8 | 0.097 | 0.118 |
| LeJEPA (ViT-L) | 81.0 | 73.1 | 0.094 | 0.114 |
| Causal-JEPA (ViT-L) | 84.4 | 78.9 | 0.076 | 0.091 |
Causal-JEPA achieves 84.4% AUCCESS on PHYRE-B, a +4.2 point improvement over V-JEPA and +3.4 over LeJEPA. On CRAFT, the gap widens to +6.1 over V-JEPA, suggesting that object-level causal reasoning provides a larger advantage on tasks requiring compositional physical understanding. CoPhy results show consistent improvements across both balls and blocks scenarios, with relative MSE reductions of 21.6% and 22.9% over V-JEPA respectively.
Standard Vision Benchmarks
| Method | ImageNet Linear (Top-1 ↑) | ImageNet 1% (Top-1 ↑) |
|---|---|---|
| I-JEPA (ViT-L) | 77.3 | 70.1 |
| V-JEPA (ViT-L) | 77.8 | 71.2 |
| LeJEPA (ViT-L) | 78.1 | 71.8 |
| Causal-JEPA (ViT-L) | 77.6 | 71.5 |
On standard ImageNet linear probing, Causal-JEPA performs comparably to V-JEPA (77.6% vs. 77.8%) and slightly below LeJEPA (78.1%). This is expected: ImageNet classification does not require causal reasoning, so the object-level masking provides no advantage over well-calibrated patch masking. The authors note that Causal-JEPA is not designed to maximize ImageNet accuracy but rather to learn representations that support physical and causal reasoning—a task where it substantially outperforms all baselines.
Ablation Studies
| Ablation | PHYRE-B (AUCCESS) | Δ vs. Full |
|---|---|---|
| Full Causal-JEPA | 84.4 | — |
| Random patch masking (I-JEPA style) | 77.6 | −6.8 |
| Superpixel masking | 82.0 | −2.4 |
| No causal consistency loss ($\lambda = 0$) | 82.1 | −2.3 |
| No SIGReg ($\mu = 0$) | 81.7 | −2.7 |
| Predictor $D_{\text{pred}} = 768$ | 81.7 | −2.7 |
| Predictor $D_{\text{pred}} = 192$ | 83.0 | −1.4 |
| Fixed EMA $\tau = 0.999$ | 83.2 | −1.2 |
| $M_{\max} = 1$ (single object) | 83.1 | −1.3 |
| $M_{\max} = 5$ (many objects) | 83.5 | −0.9 |
| Joint segmentation learning | 81.3 | −3.1 |
The ablations reveal several important findings:
- Object masking is the largest contributor. Replacing object-level masking with random patch masking (−6.8 points) accounts for the majority of the performance gap, confirming that the masking strategy is the key innovation.
- Causal consistency loss provides meaningful gains. Removing $\mathcal{L}_{\text{causal}}$ costs 2.3 points, indicating that explicitly supervising the predictor to learn causal effects (not just missing representations) is valuable.
- SIGReg remains important. Without spectral regularization, performance drops by 2.7 points—comparable to the drop from losing the causal consistency loss—confirming that collapse prevention and causal reasoning are complementary.
- Predictor capacity matters. An overly wide predictor ($D_{\text{pred}} = 768$) allows representational bypassing (−2.7 points), while an overly narrow predictor ($D_{\text{pred}} = 192$) limits expressiveness (−1.4 points).
- Multi-object masking helps modestly. Masking more than one object ($M_{\max} = 3$) is better than single-object masking ($M_{\max} = 1$, −1.3 points), as it requires understanding multi-body causal interactions.
10. Connection to JEPA Family
Lineage
Causal-JEPA sits at the intersection of two lines of JEPA development. From the architectural lineage, it inherits the core JEPA framework (LeCun, 2022): asymmetric encoder-predictor with EMA targets, prediction in latent space, and masking-based self-supervision. From V-JEPA (Bardes et al., 2024), it inherits the extension to temporal/video data and the understanding that spatiotemporal prediction can capture physical dynamics. From LeJEPA (Balestriero et al., 2025), it inherits SIGReg spectral regularization for provable collapse prevention.
What Causal-JEPA adds is fundamentally new: a change in the semantics of masking that transforms the JEPA prediction task from associative/correlational to interventional/causal. This is not merely an architectural modification but a conceptual reframing that connects JEPA to the causal inference literature (Pearl, 2009; Peters et al., 2017).
Key Contribution: From Correlation to Causation in Self-Supervised Learning
Causal-JEPA is the first JEPA variant to explicitly connect the masking operation to Pearl's interventional calculus. By masking entire objects rather than geometric patches, the prediction task shifts from $P(\mathbf{s}_{\text{target}} | \mathbf{s}_{\text{context}})$ (observational) to $P(\mathbf{s}_{\text{target}} | \text{do}(\text{remove object } i))$ (interventional). This is a principled, theoretically grounded change: the do-operator severs incoming causal arrows to the masked object, forcing the predictor to learn the causal graph structure of the scene rather than mere statistical correlations. The causal consistency loss provides additional supervision for learning the effects of interventions on remaining objects, directly training the model to answer counterfactual queries. This opens the JEPA framework to a new class of downstream tasks—physical reasoning, counterfactual prediction, and causal discovery—that are fundamentally inaccessible to correlation-based representations.
Influence and Future Directions
Causal-JEPA suggests several directions for the broader JEPA family:
- Object-centric JEPA variants. The success of object-level masking motivates exploring other object-centric operations within JEPA: object swapping, object insertion, or attribute modification as forms of intervention.
- Causal discovery. If the predictor learns causal effects, its attention patterns may reveal the causal graph of the scene, enabling unsupervised causal discovery from visual data.
- Integration with world models. The interventional prediction capability aligns Causal-JEPA with the world-model vision articulated in LeCun (2022), where agents reason about consequences of actions by simulating interventions in latent space.
- Joint object discovery. The current reliance on pre-computed segmentation masks is a limitation. Future work could integrate slot-attention or other object-discovery mechanisms into the JEPA framework, learning both objects and causal structure end-to-end.
11. Summary
Causal-JEPA reframes JEPA masking as causal intervention: by removing entire objects rather than geometric patches, the predictor must model the consequences of an object's absence, grounding a latent analogue of Pearl's $\text{do}(\cdot)$ operator in the self-supervised objective. Combined with the causal consistency loss and LeJEPA's SIGReg regularization, this yields state-of-the-art physical reasoning on PHYRE, CRAFT, and CoPhy while remaining competitive with other JEPA variants on standard ImageNet probes.
12. References
- Nam, T., Le Lidec, Q., Maes, F., LeCun, Y., & Balestriero, R. (2026). Causal-JEPA: Learning World Models through Object-Level Latent Interventions. arXiv preprint arXiv:2602.11389.
- LeCun, Y. (2022). A Path Towards Autonomous Machine Intelligence. OpenReview.
- Assran, M., Duval, Q., Misra, I., Bojanowski, P., Vincent, P., Rabbat, M., LeCun, Y., & Ballas, N. (2023). Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture. CVPR 2023.
- Bardes, A., Garrido, Q., Ponce, J., Chen, X., Rabbat, M., LeCun, Y., Assran, M., & Ballas, N. (2024). Revisiting Feature Prediction for Learning Visual Representations from Video. ECCV 2024.
- Balestriero, R., LeCun, Y., et al. (2025). LeJEPA: Latent Embedding Joint Embedding Predictive Architecture with Spectral Regularization. arXiv preprint.
- Pearl, J. (2009). Causality: Models, Reasoning, and Inference. Cambridge University Press, 2nd edition.
- Peters, J., Janzing, D., & Schölkopf, B. (2017). Elements of Causal Inference. MIT Press.
- Bakhtin, A., van der Maaten, L., Johnson, J., Gustafson, L., & Girshick, R. (2019). PHYRE: A New Benchmark for Physical Reasoning. NeurIPS 2019.
- Ates, T., Akhtar, M. S., & Keles, A. (2022). CRAFT: A Benchmark for Causal Reasoning About Forces and inTeractions. Findings of ACL 2022.
- Baradel, F., Neverova, N., Mille, J., Mori, G., & Wolf, C. (2020). CoPhy: Counterfactual Learning of Physical Dynamics. ICLR 2020.
- Grill, J.-B., Strub, F., Altché, F., Tallec, C., Richemond, P.H., Buchatskaya, E., Doersch, C., Pires, B.A., Guo, Z.D., Azar, M.G., Piot, B., Kavukcuoglu, K., Munos, R., & Valko, M. (2020). Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning. NeurIPS 2020.
- He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). Masked Autoencoders Are Scalable Vision Learners. CVPR 2022.
- Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., & Houlsby, N. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021.
- Kirillov, A., Mintun, E., Ravi, N., Mao, H., Rolland, C., Gustafson, L., Xiao, T., Whitehead, S., Berg, A.C., Lo, W.-Y., Dollár, P., & Girshick, R. (2023). Segment Anything. ICCV 2023.
- Locatello, F., Poole, B., Rätsch, G., Schölkopf, B., Bachem, O., & Tschannen, M. (2020). Weakly-Supervised Disentanglement Without Compromises. ICML 2020.