Authors: Nam, Le Lidec, Maes, LeCun, Balestriero
Date: 2026-02
Category: Physics / World Models
Derives from: LeJEPA, V-JEPA

Causal-JEPA: Learning World Models through Object-Level Latent Interventions

Nam, Le Lidec, Maes, LeCun, Balestriero — February 2026

1. Introduction

The Joint Embedding Predictive Architecture (JEPA) family has established a powerful paradigm for self-supervised representation learning: encode observations into latent space, mask portions of the input, and train a predictor to reconstruct the missing latent targets without ever decoding back to pixel space. I-JEPA introduced spatially informative multi-block masking for images; V-JEPA extended this to spatiotemporal tube masking for video; LeJEPA contributed spectral regularization via SIGReg to provably prevent representational collapse. Each of these methods, however, treats masking as a geometric operation—rectangular patches, random tubes, or frequency bands are removed without regard for the semantic structure of the scene.

Causal-JEPA (Nam et al., 2026) identifies a fundamental limitation of geometry-based masking: it cannot distinguish between correlation and causation. A patch-masked predictor may learn that a ball and a shadow co-occur without learning that the ball causes the shadow. In scenes with multiple interacting objects, geometric masking conflates the contributions of distinct entities, yielding representations that capture statistical regularities but fail to support counterfactual or interventional reasoning—the hallmark of genuine physical understanding.

The key insight of Causal-JEPA is to reinterpret the JEPA masking operation through the lens of Pearl's interventional calculus. Rather than masking arbitrary spatial regions, Causal-JEPA masks entire objects from the scene, effectively performing a latent analogue of the $\text{do}(\cdot)$ operator. When an object is removed, the predictor must infer the causal consequences of that object's absence on the remaining scene representation—learning, for instance, that removing a ball eliminates the shadow it casts, or that removing a support causes a stacked block to fall. This object-level masking transforms the standard JEPA prediction task into an interventional prediction task, grounding causal reasoning directly in the self-supervised learning objective.

Contributions

  • Object-level masking — A masking strategy that operates on semantically coherent object regions rather than geometric patches, connecting JEPA training to causal intervention.
  • Latent interventional prediction — A predictor architecture that reasons about the causal consequences of object removal in latent space, enabling counterfactual scene understanding.
  • Causal consistency loss — An auxiliary objective enforcing that the predicted latent state under intervention is consistent with the actual latent state of the intervened scene.
  • State-of-the-art physical reasoning — Strong results on physics-oriented benchmarks including PHYRE, CRAFT, and CoPhy, outperforming both patch-based JEPA variants and pixel-reconstruction baselines on tasks requiring causal understanding.

Differentiation from LeJEPA and V-JEPA

LeJEPA contributes spectral regularization (SIGReg) to prevent collapse; Causal-JEPA inherits this regularization while fundamentally changing the masking semantics. V-JEPA extends JEPA to video with spatiotemporal tube masking; Causal-JEPA also operates on temporal scene sequences but replaces tube masking with object-level removal, requiring object-aware processing that V-JEPA does not possess. Where V-JEPA asks "what happens in this spatiotemporal region?", Causal-JEPA asks "what happens because of this object?"

2. Method

Intuition: The "remove and predict" game.
Imagine a child's physics experiment: you have a scene with blocks, ramps, and balls. Cover your eyes, and someone removes one object. Now open your eyes and try to predict what changed. If a ramp was removed, the ball that was rolling down it is now on the ground. If a wall was removed, the ball that was bouncing off it has rolled away. To play this game well, you must understand not just what objects look like, but how they causally influence each other. Causal-JEPA trains a neural network to play exactly this game—but in latent representation space rather than pixel space.

The method proceeds in three conceptual stages:

Stage 1: Object-aware scene encoding. Given an input scene (a single image or a short video clip), the encoder processes the full scene into a set of latent tokens. Crucially, the architecture maintains object-level structure: each token corresponds to a semantically meaningful object or scene region, not merely a spatial patch. Object segmentation can be obtained from an off-the-shelf method (e.g., SAM-based masks) or learned jointly. The target encoder (an EMA copy) processes the same full scene to produce target representations.

Stage 2: Object-level intervention. One or more objects are selected for "intervention"—their tokens are removed from the context provided to the predictor. This is the analogue of Pearl's $\text{do}()$ operation: rather than merely observing what the scene looks like without seeing a particular region (as in patch masking), we are simulating what the world would be like if that object did not exist. The predictor receives the remaining (non-intervened) object tokens plus a set of learnable mask tokens indicating which objects were removed.

Stage 3: Interventional prediction. The predictor must reconstruct the target-encoder representations of the remaining objects as they would appear in the absence of the intervened objects. This is the critical distinction from standard JEPA: the predictor is not merely filling in missing patches, but reasoning about causal consequences. If a light source is removed, the predictor should predict that shadow regions in other objects' representations will change accordingly.

Analogy: Patch masking is observation; object masking is intervention.
In causal inference, observing that $X$ and $Y$ co-occur ($P(Y|X)$) is fundamentally different from intervening to set $X$ and observing $Y$ ($P(Y|\text{do}(X))$). Standard JEPA patch masking is analogous to conditional observation—predicting what's behind a curtain. Causal-JEPA's object masking is analogous to intervention—predicting what happens when an entity is removed from the causal graph. This difference is what enables genuine causal understanding rather than mere pattern completion.

The training signal comes from comparing the predictor's output with target-encoder representations of scenes where the object is actually absent (either by rendering/composing such scenes or by using the factored structure of the latent space). Combined with LeJEPA-style spectral regularization to prevent collapse and EMA-based target stabilization inherited from the JEPA family, this yields representations that encode not just appearance but causal structure.

3. Model Overview

At-a-Glance

  • Input: images or short video clips paired with object segmentation masks
  • Masking: object-level masking; entire objects removed from the scene context
  • Context encoder: Vision Transformer (ViT) processing context tokens (non-masked objects)
  • Target encoder: EMA copy of the context encoder; processes the full scene to produce targets
  • Predictor: lightweight Transformer with object-conditioned cross-attention; predicts target representations of affected objects
  • Loss: smooth-$\ell_1$ latent prediction loss + SIGReg spectral regularization + causal consistency term
  • Key result: state-of-the-art on PHYRE (+4.2% AUCCESS over V-JEPA), CRAFT (+6.1% accuracy), CoPhy (best counterfactual prediction)
  • Parameters: ~307M (ViT-L/16 encoder) + ~24M predictor (reported configuration)

Training Architecture Diagram

[Training architecture diagram]
Figure 1. Causal-JEPA training architecture. The full scene is processed by the frozen target encoder (EMA) to produce target representations. An object-level mask removes selected objects from the context, which is processed by the trainable context encoder. The predictor receives context embeddings plus learnable mask tokens and predicts the target representations of masked objects. Loss combines prediction error, causal consistency, and SIGReg spectral regularization. Gradients flow only through the context encoder and predictor (solid borders).

4. Main Components of Causal-JEPA

4.1 Context Encoder

WHAT: The context encoder $f_\theta$ is a Vision Transformer (ViT-L/16 in the primary configuration) that processes the context portion of the scene—i.e., all tokens corresponding to objects that were not removed by the object-level mask. Unlike standard ViT processing where all patch tokens are processed uniformly, Causal-JEPA's encoder receives a variable-length sequence depending on how many objects (and their constituent patches) were masked.

HOW: The input image is patchified into a grid of $N = (H/p) \times (W/p)$ tokens of dimension $D = 1024$ (for ViT-L), where $p = 16$ is the patch size. Object masks from the segmentation module determine which patches belong to which objects. Patches belonging to masked (intervened) objects are removed from the token sequence before encoding. Positional embeddings are added before masking so that the encoder retains spatial awareness of where the context tokens originate. The encoder consists of 24 Transformer blocks with 16 attention heads.

WHY: Processing only context tokens (rather than the full sequence with mask tokens inserted) serves two purposes: (1) it forces the encoder to build representations from partial information, encouraging robust features that do not depend on any single object, and (2) it reduces computational cost roughly proportionally to the fraction of masked tokens. Ablations reported in the paper show that processing context-only tokens yields +1.8% on PHYRE compared to processing the full sequence with zeroed-out masked positions, suggesting that the encoder benefits from not being distracted by uninformative placeholder tokens.
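The context-only pipeline above can be sketched in a few lines of numpy. Dimensions are toy-sized and the patch embeddings, positional embeddings, and masked indices are random stand-ins, not the paper's weights; the point is only the order of operations (positional embeddings added before tokens are dropped):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes for illustration (the paper's ViT-L/16 uses N = 196 tokens, D = 1024).
N, D = 16, 8

patch_tokens = rng.normal(size=(N, D))   # stand-in patchified image embeddings
pos_embed = rng.normal(size=(N, D))      # stand-in positional embeddings

# Positional embeddings are added BEFORE masking, so the surviving context
# tokens retain knowledge of where they came from in the image.
tokens = patch_tokens + pos_embed

# Suppose patches {2, 3, 7} belong to the intervened object: drop them.
masked = np.array([2, 3, 7])
context_idx = np.setdiff1d(np.arange(N), masked)
context_tokens = tokens[context_idx]     # variable-length context sequence
```

The encoder then processes `context_tokens`, a sequence whose length varies with the fraction of the scene that was intervened on.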

4.2 Target Encoder (EMA)

WHAT: The target encoder $f_{\bar{\theta}}$ is a momentum-updated copy of the context encoder that processes the full, unmasked scene to produce target representations. It is never updated by gradient descent; instead, its parameters are an exponential moving average (EMA) of the context encoder's parameters.

HOW: After each training step, the target encoder parameters are updated as:

$$\bar{\theta} \leftarrow \tau \bar{\theta} + (1 - \tau) \theta$$

where $\tau$ follows a cosine schedule from $\tau_0 = 0.996$ to $\tau_1 = 1.0$ over the course of training. The target encoder receives all $N$ patch tokens (no masking) and produces a complete set of representations $\mathbf{s} = f_{\bar{\theta}}(\mathbf{x}) \in \mathbb{R}^{N \times D}$. The target representations for the masked objects are extracted by selecting tokens at positions corresponding to the masked object patches.

WHY: The EMA target encoder provides a slowly-evolving target signal that stabilizes training and prevents representational collapse. Because it sees the full scene, its representations capture complete object interactions—providing the "ground truth" against which the predictor's interventional predictions are measured. The cosine EMA schedule, inherited from BYOL/I-JEPA practice, ensures that targets are initially responsive to encoder updates (lower $\tau$) and become increasingly stable as training progresses (higher $\tau$). Ablations show that a fixed $\tau = 0.999$ underperforms the cosine schedule by 1.2% on CRAFT, consistent with findings across the JEPA family.
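The cosine momentum schedule and EMA update can be written directly from the formula above (a minimal stdlib sketch; `ema_tau` and `ema_update` are illustrative names, and parameters are represented as flat lists of floats rather than network tensors):

```python
import math

def ema_tau(step, total_steps, tau0=0.996, tau1=1.0):
    """Cosine interpolation of the EMA momentum from tau0 to tau1."""
    progress = step / total_steps
    return tau1 - (tau1 - tau0) * (math.cos(math.pi * progress) + 1) / 2

def ema_update(target_params, online_params, tau):
    """theta_bar <- tau * theta_bar + (1 - tau) * theta, applied per parameter."""
    return [tau * tb + (1 - tau) * t
            for tb, t in zip(target_params, online_params)]
```

Early in training `tau` stays near 0.996, so targets track the encoder; by the end `tau` approaches 1.0 and targets are effectively frozen.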

4.3 Predictor

WHAT: The predictor $g_\phi$ is a lightweight Transformer that takes context encoder outputs (representations of non-masked objects) and learnable mask tokens as input, and predicts what the target encoder would produce for the masked objects' tokens. This is the component that must learn causal reasoning: given the remaining objects, predict the latent state of removed objects (or, more precisely, predict the target representations at the masked positions).

HOW: The predictor is a 6-layer Transformer with 8 attention heads and hidden dimension $D_{\text{pred}} = 384$. It operates as follows:

  1. Context tokens from the encoder ($\mathbb{R}^{N_c \times D}$) are projected to predictor dimension via a linear layer: $\mathbf{h}_c = W_{\text{proj}} \cdot f_\theta(\mathbf{x}_{\text{context}}) \in \mathbb{R}^{N_c \times D_{\text{pred}}}$.
  2. Learnable mask tokens $\mathbf{m} \in \mathbb{R}^{N_m \times D_{\text{pred}}}$ are initialized for each masked position, augmented with positional embeddings indicating where the masked objects were located.
  3. The concatenated sequence $[\mathbf{h}_c; \mathbf{m}]$ is processed by the predictor Transformer. Cross-attention allows mask tokens to attend to context tokens, while self-attention among mask tokens enables reasoning about relationships between multiple masked objects.
  4. The output mask tokens are projected back to encoder dimension via a linear head: $\hat{\mathbf{s}} = W_{\text{out}} \cdot g_\phi([\mathbf{h}_c; \mathbf{m}])_{\text{mask}} \in \mathbb{R}^{N_m \times D}$.

WHY: The predictor's narrow bottleneck ($D_{\text{pred}} = 384$ vs. $D = 1024$) is critical for preventing a degenerate solution in which the encoder collapses representations to make prediction trivial. The bottleneck forces the predictor to extract and use high-level causal structure rather than memorizing low-level details. Ablations show that widening the predictor to $D_{\text{pred}} = 768$ degrades PHYRE performance by 2.7%, while narrowing it to $D_{\text{pred}} = 192$ degrades it by 1.4%, indicating that $384$ strikes the right balance between capacity and information bottleneck. The 6-layer depth was likewise chosen over a 12-layer alternative: a lighter predictor avoids representational bypassing, in which a predictor powerful enough to solve the task on its own removes the pressure on the encoder to learn meaningful features.
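The four-step flow above can be sketched in numpy, with a single unparameterized self-attention mixing step standing in for the 6-layer predictor Transformer (all weights here are random illustrations, not the trained model):

```python
import numpy as np

rng = np.random.default_rng(0)
D, D_pred = 16, 8                 # toy dims (paper: D = 1024, D_pred = 384)
N_c, N_m = 10, 3                  # context tokens, masked positions

h_ctx = rng.normal(size=(N_c, D))              # context-encoder outputs
W_proj = rng.normal(size=(D, D_pred)) * 0.1    # step 1: input projection
W_out = rng.normal(size=(D_pred, D)) * 0.1     # step 4: output projection
mask_tok = rng.normal(size=(N_m, D_pred))      # step 2: mask tokens (+pos emb)

# Step 3: concatenate projected context with mask tokens and let them attend.
x = np.concatenate([h_ctx @ W_proj, mask_tok], axis=0)   # (N_c + N_m, D_pred)

def self_attention(x):
    """One softmax-attention mixing step (stand-in for the predictor stack)."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return x + w @ x                           # residual attention update

x = self_attention(x)
s_hat = x[N_c:] @ W_out           # step 4: predictions at masked positions
```

Only the last $N_m$ rows (the mask-token positions) are projected back to encoder dimension and compared against targets.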

4.4 Masking Strategy

WHAT: Causal-JEPA's defining innovation is object-level masking. Instead of masking rectangular patches (I-JEPA), spatiotemporal tubes (V-JEPA), or random token subsets, entire semantically coherent objects are removed from the context. This transforms the prediction task from geometric inpainting to causal intervention.

HOW: Object masks are obtained via one of two methods:

  • Pre-computed masks: For datasets with available segmentation (e.g., synthetic physics benchmarks), ground-truth object masks are used directly.
  • Learned/external masks: For natural images, an off-the-shelf segmentation model (the paper uses SAM-based masks) provides object proposals. The top-$K$ objects by area are retained as maskable entities.

At each training step, $M \sim \text{Uniform}(1, M_{\max})$ objects are selected for masking, where $M_{\max} = 3$ by default. Selection is biased toward objects that occupy between 5% and 40% of the image area, avoiding trivially small objects and near-total occlusions. All patches overlapping the selected objects' segmentation masks are removed from the context encoder's input.

[Figure: object-level masking vs. patch masking, illustrated on a scene with a ball, ramp, wall, and shadow]
Patch masking (I-JEPA) models $P(\text{target} \mid \text{context patches})$: it masks arbitrary rectangles that cut across object boundaries, so the predictor learns spatial interpolation and captures the observational distribution $P(Y|X)$, which cannot distinguish correlation from causation.
Object masking (Causal-JEPA) models $P(\text{target} \mid \text{do}(\text{remove obj}))$: it masks semantically coherent object regions, so the predictor learns causal consequences and captures the interventional distribution $P(Y|\text{do}(X))$, i.e., true cause-and-effect relationships.
Figure 2. Object-level masking (right) versus patch masking (left). Patch masking removes arbitrary rectangular regions that cut across object boundaries, encouraging spatial interpolation. Object masking removes entire semantic entities, encouraging the predictor to reason about causal consequences — e.g., removing a ball also eliminates its shadow. This connects to Pearl's interventional distribution $P(Y|\text{do}(X))$ versus the observational $P(Y|X)$.

WHY: The choice of object-level masking is motivated by three observations. First, causal structure in physical scenes is organized around objects, not patches—forces, contacts, shadows, and occlusions are properties of inter-object relationships. Second, object removal is a natural analogue of the $\text{do}()$ operator from causal inference: it severs incoming causal arrows to the removed object while preserving the rest of the causal graph. Third, object-level masking provides a more semantically consistent training signal than patch masking: the predictor always faces a well-defined counterfactual question ("what if this object weren't here?") rather than an ill-defined geometric one ("what's behind this rectangle?"). Ablations demonstrate that object-level masking outperforms random patch masking by 6.8% on PHYRE and 8.3% on CRAFT, and outperforms superpixel-based masking (which approximates objects but does not guarantee semantic coherence) by 2.4% on PHYRE.

4.5 Loss Function

The total loss is a weighted combination of three terms:

$$\mathcal{L} = \mathcal{L}_{\text{pred}} + \lambda \cdot \mathcal{L}_{\text{causal}} + \mu \cdot \mathcal{L}_{\text{SIGReg}}$$

Prediction Loss ($\mathcal{L}_{\text{pred}}$): The primary objective measures the discrepancy between the predictor's output and the target encoder's representations at masked positions. Causal-JEPA uses the smooth $\ell_1$ (Huber) loss:

$$\mathcal{L}_{\text{pred}} = \frac{1}{|M|} \sum_{i \in M} \text{smooth}_{\ell_1}\!\left(\hat{\mathbf{s}}_i - \text{sg}(\mathbf{s}_i^{\text{tgt}})\right)$$

where $M$ is the set of masked token indices, $\hat{\mathbf{s}}_i = g_\phi([\mathbf{h}_c; \mathbf{m}])_i \in \mathbb{R}^D$ is the predictor's output for masked position $i$, $\mathbf{s}_i^{\text{tgt}} = f_{\bar{\theta}}(\mathbf{x})_i \in \mathbb{R}^D$ is the target encoder's representation at position $i$, $\text{sg}(\cdot)$ denotes stop-gradient, and:

$$\text{smooth}_{\ell_1}(\mathbf{z}) = \begin{cases} \frac{1}{2}\|\mathbf{z}\|^2 / \beta & \text{if } \|\mathbf{z}\| < \beta \\ \|\mathbf{z}\| - \frac{\beta}{2} & \text{otherwise} \end{cases}$$

with $\beta = 1.0$. The smooth $\ell_1$ loss is preferred over MSE because it is less sensitive to outlier targets, which arise when the EMA target encoder produces temporarily inconsistent representations during early training.
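The piecewise definition above translates directly into code (a stdlib sketch operating on the norm of a residual vector, as in the formula; `smooth_l1` is an illustrative name):

```python
import math

def smooth_l1(z, beta=1.0):
    """Smooth-l1 (Huber-style) loss on the norm of residual vector z.

    Quadratic for small residuals (||z|| < beta), linear beyond, so a few
    outlier targets from an unsettled EMA encoder cannot dominate the loss.
    """
    n = math.sqrt(sum(v * v for v in z))
    return 0.5 * n * n / beta if n < beta else n - 0.5 * beta
```

At $\|\mathbf{z}\| = \beta$ the two branches meet with matching value and slope, which is what makes the loss smooth.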

Causal Consistency Loss ($\mathcal{L}_{\text{causal}}$): This is Causal-JEPA's novel auxiliary objective. It enforces that the predictor's output for non-masked objects is consistent with how those objects' representations change when an object is removed. Formally, let $\mathbf{s}_j^{\text{full}} = f_{\bar{\theta}}(\mathbf{x}_{\text{full}})_j$ be the target encoder's representation of object $j$ in the full scene, and let $\mathbf{s}_j^{\text{int}} = f_{\bar{\theta}}(\mathbf{x}_{\setminus i})_j$ be its representation in the scene with object $i$ removed. The causal consistency loss is:

$$\mathcal{L}_{\text{causal}} = \frac{1}{|C|} \sum_{j \in C} \left\| g_\phi([\mathbf{h}_c; \mathbf{m}])_j^{\text{residual}} - \text{sg}\!\left(\mathbf{s}_j^{\text{full}} - \mathbf{s}_j^{\text{int}}\right) \right\|_2^2$$

where $C$ is the set of context (non-masked) object indices, and $g_\phi(\cdot)_j^{\text{residual}}$ is an auxiliary output head on the predictor that estimates the change in representation for context objects due to the intervention. This loss ensures that the predictor does not merely predict target representations independently but learns the causal effect of object removal on the remaining scene.
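A minimal numpy sketch of this term, taking the residual-head outputs and the two target-encoder passes as arrays (stop-gradient is implicit since nothing here carries gradients; `causal_consistency_loss` is an illustrative name):

```python
import numpy as np

def causal_consistency_loss(r_hat, s_full, s_int):
    """L_causal: squared error between the predictor's residual-head output
    and the (stop-gradiented) change in context-object representations
    caused by removing the intervened object.

    r_hat, s_full, s_int: arrays of shape (num_context_objects, D).
    """
    delta = s_full - s_int          # sg(s_full - s_int) in the paper
    return float(np.mean(np.sum((r_hat - delta) ** 2, axis=-1)))
```

When the predictor correctly estimates every per-object change, the loss is exactly zero; predicting "no change" is penalized whenever the intervention actually affects a context object.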

All variables:

  • $M$: set of token indices corresponding to masked (intervened) objects
  • $C$: set of token indices corresponding to context (non-masked) objects
  • $\hat{\mathbf{s}}_i \in \mathbb{R}^D$: predictor output at masked position $i$
  • $\mathbf{s}_i^{\text{tgt}} \in \mathbb{R}^D$: target encoder output at position $i$ (full scene)
  • $\mathbf{s}_j^{\text{full}} \in \mathbb{R}^D$: target representation of object $j$ in full scene
  • $\mathbf{s}_j^{\text{int}} \in \mathbb{R}^D$: target representation of object $j$ in intervened scene (object $i$ removed)
  • $g_\phi(\cdot)_j^{\text{residual}} \in \mathbb{R}^D$: predictor's estimate of the causal effect on object $j$
  • $\lambda = 0.5$: weight for causal consistency loss
  • $\mu = 0.01$: weight for SIGReg regularization
  • $\beta = 1.0$: smooth $\ell_1$ transition threshold

SIGReg Spectral Regularization ($\mathcal{L}_{\text{SIGReg}}$): Inherited from LeJEPA, this term prevents representational collapse by regularizing the singular value distribution of the encoder output. Given a batch of encoder outputs $\mathbf{S} \in \mathbb{R}^{B \times D}$ (averaged over spatial positions per sample), the singular values $\sigma_1, \ldots, \sigma_D$ of the centered matrix $\mathbf{S} - \bar{\mathbf{S}}$ are computed, and the regularizer encourages a uniform distribution:

$$\mathcal{L}_{\text{SIGReg}} = \text{KL}\!\left(p \,\middle\|\, \mathcal{U}\right) = \sum_{k=1}^{D} p_k \log\!\left(D\, p_k\right), \qquad p_k = \frac{\sigma_k}{\sum_{k'=1}^{D} \sigma_{k'}}$$

where the normalized singular values are treated as a probability distribution and compared to the uniform distribution. This ensures that the encoder uses all $D$ dimensions of its representation space, preventing collapse to a lower-dimensional manifold.
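This regularizer can be computed in a few numpy lines (a sketch under the assumption that the batch size is at least $D$ so all $D$ singular values exist; `sigreg` is an illustrative name):

```python
import numpy as np

def sigreg(S):
    """KL(normalized singular values || uniform) on centered batch embeddings.

    S: (B, D) batch of per-sample pooled encoder outputs. Returns ~0 when the
    spectrum is flat (all D dimensions used) and up to log(D) under collapse.
    """
    S = S - S.mean(axis=0, keepdims=True)          # center the batch
    sigma = np.linalg.svd(S, compute_uv=False)     # singular values
    p = sigma / sigma.sum()                        # treat as a distribution
    D = S.shape[1]
    return float(np.sum(p * np.log(np.clip(p * D, 1e-12, None))))
```

A rank-collapsed batch concentrates all mass on one singular value and is penalized with a value near $\log D$, while a well-spread representation scores near zero.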

4.6 Object Segmentation Module

WHAT: A variant-specific component that provides the object masks needed for object-level masking. This module bridges the gap between raw pixel inputs and the object-level abstraction that Causal-JEPA requires.

HOW: The paper explores two configurations. For synthetic benchmarks (PHYRE, CRAFT, CoPhy), ground-truth object segmentations are available and used directly—each scene is composed of distinct rigid bodies whose boundaries are known. For natural image experiments, a pre-trained SAM (Segment Anything Model) generates object proposals, which are filtered by area (retaining objects covering 5–40% of the image) and confidence score. The segmentation module is not trained jointly with the JEPA components; it is frozen and serves as a preprocessing step.

WHY: Decoupling segmentation from the JEPA training loop simplifies the optimization landscape and allows the causal reasoning components to focus on learning physics rather than simultaneously solving segmentation. The paper notes that jointly learning segmentation degrades PHYRE performance by 3.1%, likely because noisy segmentation masks during early training produce inconsistent causal supervision. However, the authors identify joint object discovery as an important direction for future work, connecting to the slot-attention and object-centric representation literature.

5. Implementation Details

  • Encoder architecture: ViT-L/16 (24 layers, 16 heads, $D=1024$)
  • Predictor architecture: 6 layers, 8 heads, $D_{\text{pred}}=384$
  • Patch size: $16 \times 16$
  • Input resolution: $224 \times 224$ (images), $224 \times 224 \times T$ (video, $T=16$ frames)
  • Optimizer: AdamW ($\beta_1=0.9$, $\beta_2=0.95$, weight decay $= 0.05$)
  • Base learning rate: $1.5 \times 10^{-4}$ (scaled linearly with batch size)
  • LR schedule: cosine decay with 40-epoch linear warmup
  • Batch size: 1024 (across GPUs)
  • Training epochs: 300 (synthetic), 600 (natural images)
  • GPUs: 32 × A100 (80 GB)
  • EMA schedule ($\tau$): cosine from 0.996 → 1.0
  • Max masked objects ($M_{\max}$): 3
  • Object area filter: 5%–40% of image area
  • Causal consistency weight ($\lambda$): 0.5
  • SIGReg weight ($\mu$): 0.01
  • Smooth $\ell_1$ $\beta$: 1.0
  • Total parameters: ~331M (307M encoder + 24M predictor)
  • Training time: ~72 hours (300 epochs on synthetic data)

Note: No public code repository is available for Causal-JEPA. Hyperparameters are reported from the paper; some values (noted where relevant) are inherited from the V-JEPA and LeJEPA codebases that Causal-JEPA builds upon.
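The learning-rate row (cosine decay with linear warmup) can be made concrete with a small helper. This is a generic sketch in steps rather than epochs; `lr_at` and its arguments are illustrative, not from a released codebase:

```python
import math

def lr_at(step, total_steps, base_lr=1.5e-4, warmup_steps=0):
    """Linear warmup to base_lr, then cosine decay to zero."""
    if warmup_steps and step < warmup_steps:
        return base_lr * step / warmup_steps       # linear ramp
    t = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * base_lr * (1 + math.cos(math.pi * t))
```

With 40 warmup epochs out of 300, roughly the first 13% of steps ramp up linearly before the cosine decay begins.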

6. Algorithm

Algorithm 1: Causal-JEPA Training
Input: Dataset $\mathcal{D}$ of scenes with object masks; encoder $f_\theta$; target encoder $f_{\bar{\theta}}$; predictor $g_\phi$; EMA schedule $\tau(t)$; loss weights $\lambda, \mu$; max masked objects $M_{\max}$
Output: Trained encoder $f_\theta$ producing causally-aware representations
 
1 Initialize $\bar{\theta} \leftarrow \theta$
2 for $t = 1$ to $T_{\max}$ do
3 Sample mini-batch $\{(\mathbf{x}_b, \{\mathbf{o}_k^b\}_{k=1}^{K_b})\}_{b=1}^{B}$ from $\mathcal{D}$  // scenes with object masks
4 for each scene $b$ in batch do
5 Sample $M_b \sim \text{Uniform}(1, M_{\max})$  // number of objects to mask
6 Select $\mathcal{I}_b \subset \{1, \ldots, K_b\}$ with $|\mathcal{I}_b| = M_b$  // objects to intervene on (area-filtered)
7 Compute patch sets: $P_{\text{mask}}^b \leftarrow \bigcup_{k \in \mathcal{I}_b} \text{patches}(\mathbf{o}_k^b)$; $\; P_{\text{ctx}}^b \leftarrow \{1,\ldots,N\} \setminus P_{\text{mask}}^b$
8 $\mathbf{h}_c^b \leftarrow f_\theta(\mathbf{x}_b[P_{\text{ctx}}^b])$  // encode context patches only; $\mathbf{h}_c^b \in \mathbb{R}^{|P_{\text{ctx}}^b| \times D}$
9 $\mathbf{s}^b \leftarrow f_{\bar{\theta}}(\mathbf{x}_b)$; $\;\mathbf{s}^{b,\text{int}} \leftarrow f_{\bar{\theta}}(\mathbf{x}_b^{\setminus \mathcal{I}_b})$  // target encoder on the full and intervened scenes (no grad); $\mathbf{s}^b \equiv \mathbf{s}^{b,\text{full}} \in \mathbb{R}^{N \times D}$
10 Initialize mask tokens $\mathbf{m}^b \in \mathbb{R}^{|P_{\text{mask}}^b| \times D_{\text{pred}}}$ with positional embeddings
11 $\hat{\mathbf{s}}^b, \hat{\mathbf{r}}^b \leftarrow g_\phi([\text{proj}(\mathbf{h}_c^b); \mathbf{m}^b])$  // predicted targets + residual outputs
end for
12 $\mathcal{L}_{\text{pred}} \leftarrow \frac{1}{B} \sum_{b=1}^{B} \frac{1}{|P_{\text{mask}}^b|} \sum_{i \in P_{\text{mask}}^b} \text{smooth}_{\ell_1}(\hat{\mathbf{s}}_i^b - \text{sg}(\mathbf{s}_i^b))$
13 $\mathcal{L}_{\text{causal}} \leftarrow \frac{1}{B} \sum_{b=1}^{B} \frac{1}{|P_{\text{ctx}}^b|} \sum_{j \in P_{\text{ctx}}^b} \|\hat{\mathbf{r}}_j^b - \text{sg}(\mathbf{s}_j^{b,\text{full}} - \mathbf{s}_j^{b,\text{int}})\|_2^2$
14 $\mathcal{L}_{\text{SIGReg}} \leftarrow \text{KL}\!\left(\text{normalize}(\text{svd}(\bar{\mathbf{H}}_c)) \| \mathcal{U}\right)$  // spectral reg. on centered encoder outputs
15 $\mathcal{L} \leftarrow \mathcal{L}_{\text{pred}} + \lambda \cdot \mathcal{L}_{\text{causal}} + \mu \cdot \mathcal{L}_{\text{SIGReg}}$
16 Update $\theta, \phi$ via AdamW on $\nabla_{\theta,\phi} \mathcal{L}$
17 $\bar{\theta} \leftarrow \tau(t) \cdot \bar{\theta} + (1 - \tau(t)) \cdot \theta$  // EMA update
end for
return $f_\theta$

 

Algorithm 2: Object-Level Intervention Masking
Input: Scene $\mathbf{x}$ with $K$ object masks $\{\mathbf{o}_k\}_{k=1}^{K}$; patch grid $N = (H/p) \times (W/p)$; max objects $M_{\max}$; area bounds $[a_{\min}, a_{\max}]$
Output: Context patch set $P_{\text{ctx}}$, mask patch set $P_{\text{mask}}$, intervention set $\mathcal{I}$
 
1 $\mathcal{E} \leftarrow \emptyset$  // eligible objects
2 for $k = 1$ to $K$ do
3 $a_k \leftarrow |\text{patches}(\mathbf{o}_k)| / N$  // fractional area of object $k$
4 if $a_{\min} \leq a_k \leq a_{\max}$ then $\mathcal{E} \leftarrow \mathcal{E} \cup \{k\}$
end for
5 if $|\mathcal{E}| = 0$ then fall back to random patch masking (I-JEPA style)
6 $M \leftarrow \min(\text{Uniform}(1, M_{\max}), |\mathcal{E}|)$
7 $\mathcal{I} \leftarrow \text{sample}(\mathcal{E}, M, \text{replace}=\text{False})$  // select $M$ objects for intervention
8 $P_{\text{mask}} \leftarrow \bigcup_{k \in \mathcal{I}} \{i : \text{patch}_i \cap \mathbf{o}_k \neq \emptyset\}$  // all patches overlapping masked objects
9 $P_{\text{ctx}} \leftarrow \{1, \ldots, N\} \setminus P_{\text{mask}}$
return $P_{\text{ctx}}, P_{\text{mask}}, \mathcal{I}$
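Algorithm 2 translates almost line-for-line into numpy. In this sketch, `object_patches` (mapping object id to the set of patch indices it covers) is an assumed input format, and the fallback masks a quarter of the patches at random:

```python
import numpy as np

def object_intervention_mask(object_patches, N, m_max=3,
                             a_min=0.05, a_max=0.40, rng=None):
    """Select 1..m_max area-eligible objects to intervene on.

    Returns (context_patches, masked_patches, intervened_object_ids).
    """
    rng = rng or np.random.default_rng()
    # Lines 1-4: keep objects covering between a_min and a_max of the grid.
    eligible = [k for k, p in object_patches.items()
                if a_min <= len(p) / N <= a_max]
    if not eligible:
        # Line 5: fall back to random patch masking (I-JEPA style).
        masked = set(rng.choice(N, size=N // 4, replace=False).tolist())
        return sorted(set(range(N)) - masked), sorted(masked), []
    # Lines 6-7: sample the intervention set without replacement.
    m = min(rng.integers(1, m_max + 1), len(eligible))
    chosen = rng.choice(eligible, size=m, replace=False).tolist()
    # Lines 8-9: union of the chosen objects' patches, complement is context.
    masked = set().union(*(object_patches[k] for k in chosen))
    return sorted(set(range(N)) - masked), sorted(masked), chosen
```

A scene where only one object passes the area filter deterministically masks exactly that object, which makes the routine easy to sanity-check.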

 

Algorithm 3: Causal-JEPA Inference (Feature Extraction)
Input: Trained encoder $f_\theta$; test scene $\mathbf{x}_{\text{test}}$; downstream task head $h_\psi$
Output: Task prediction $\hat{y}$
 
1 Patchify $\mathbf{x}_{\text{test}}$ into $N$ tokens  // no masking at inference
2 $\mathbf{s} \leftarrow f_\theta(\mathbf{x}_{\text{test}})$  // encode full scene; $\mathbf{s} \in \mathbb{R}^{N \times D}$
3 $\bar{\mathbf{s}} \leftarrow \frac{1}{N} \sum_{i=1}^{N} \mathbf{s}_i$  // global average pool; $\bar{\mathbf{s}} \in \mathbb{R}^{D}$
4 $\hat{y} \leftarrow h_\psi(\bar{\mathbf{s}})$  // downstream head (linear probe or fine-tuned MLP)
return $\hat{y}$
 
// For interventional/counterfactual inference:
5 Given object masks, remove object $i$: $\mathbf{h}_c \leftarrow f_\theta(\mathbf{x}_{\text{test}}[P_{\text{ctx}}])$
6 $\hat{\mathbf{s}}_{\text{intervention}} \leftarrow g_\phi([\text{proj}(\mathbf{h}_c); \mathbf{m}])$  // predict counterfactual scene state
7 $\hat{y}_{\text{cf}} \leftarrow h_\psi(\text{pool}(\hat{\mathbf{s}}_{\text{intervention}}))$  // counterfactual task prediction
return $\hat{y}_{\text{cf}}$
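The counterfactual branch (lines 5-7) is sketched below with toy stand-ins for the trained networks: `encode` is an identity function and `predict` broadcasts the mean context token, purely to show the data flow, not the model's behavior:

```python
import numpy as np

def counterfactual_features(encode, predict, x_tokens, ctx_idx, n_masked):
    """Interventional inference: encode the scene without the intervened
    object, predict the latent state under the intervention, and pool the
    predicted tokens for the downstream task head."""
    h_c = encode(x_tokens[ctx_idx])      # context-only encoding
    s_hat = predict(h_c, n_masked)       # predicted targets, (n_masked, D)
    return s_hat.mean(axis=0)            # pooled counterfactual feature

# Toy stand-ins (NOT the trained encoder/predictor).
encode = lambda x: x
predict = lambda h, m: np.tile(h.mean(axis=0), (m, 1))

x = np.arange(12.0).reshape(4, 3)        # 4 tokens, D = 3
feat = counterfactual_features(encode, predict, x, np.array([0, 1, 3]), 1)
```

The pooled counterfactual feature then feeds the same task head $h_\psi$ used in standard evaluation.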

7. Training

Step-by-Step: One Training Iteration

  1. Sample batch. Draw $B = 1024$ scenes from $\mathcal{D}$, each with pre-computed object segmentation masks.
  2. Object selection. For each scene, sample $M \in \{1, 2, 3\}$ eligible objects (area between 5% and 40%) to mask. Compute patch sets $P_{\text{mask}}$ and $P_{\text{ctx}}$.
  3. Target encoding (no grad). Pass the full, unmasked scene through the frozen target encoder $f_{\bar{\theta}}$ to obtain target representations $\mathbf{s}^{\text{tgt}} \in \mathbb{R}^{N \times D}$. Also compute the target encoder's output on the intervened scene (with masked object pixels zeroed out or removed) to obtain $\mathbf{s}^{\text{int}}$, needed for the causal consistency loss.
  4. Context encoding (grad). Pass only the context patches $P_{\text{ctx}}$ through the trainable context encoder $f_\theta$ with their original positional embeddings. Output: $\mathbf{h}_c \in \mathbb{R}^{|P_{\text{ctx}}| \times D}$.
  5. Prediction (grad). Project context representations to predictor dimension. Concatenate with positional-embedded mask tokens $\mathbf{m} \in \mathbb{R}^{|P_{\text{mask}}| \times D_{\text{pred}}}$. Process through the 6-layer predictor Transformer. Extract predicted targets $\hat{\mathbf{s}} \in \mathbb{R}^{|P_{\text{mask}}| \times D}$ (via output projection) and residual predictions $\hat{\mathbf{r}} \in \mathbb{R}^{|P_{\text{ctx}}| \times D}$ (via auxiliary head).
  6. Loss computation. Compute $\mathcal{L}_{\text{pred}}$ (smooth $\ell_1$ between $\hat{\mathbf{s}}$ and $\text{sg}(\mathbf{s}^{\text{tgt}}[P_{\text{mask}}])$), $\mathcal{L}_{\text{causal}}$ ($\ell_2$ between $\hat{\mathbf{r}}$ and the stop-gradiented difference $\mathbf{s}^{\text{full}} - \mathbf{s}^{\text{int}}$ at context positions), and $\mathcal{L}_{\text{SIGReg}}$ (spectral regularization on $\mathbf{h}_c$). Combine: $\mathcal{L} = \mathcal{L}_{\text{pred}} + 0.5 \cdot \mathcal{L}_{\text{causal}} + 0.01 \cdot \mathcal{L}_{\text{SIGReg}}$.
  7. Gradient update. Backpropagate $\mathcal{L}$ through the predictor $g_\phi$ and context encoder $f_\theta$. Update parameters via AdamW.
  8. EMA update. Update target encoder: $\bar{\theta} \leftarrow \tau(t) \cdot \bar{\theta} + (1 - \tau(t)) \cdot \theta$ where $\tau(t)$ follows a cosine schedule.
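Step 6 of the iteration above can be pulled together into one function using the paper's weights $\lambda = 0.5$ and $\mu = 0.01$ (a gradient-free numpy sketch, so the stop-gradients are implicit; `causal_jepa_loss` is an illustrative name):

```python
import numpy as np

def causal_jepa_loss(s_hat, s_tgt, r_hat, s_full_ctx, s_int_ctx, h_ctx,
                     lam=0.5, mu=0.01, beta=1.0):
    """Combine the three Causal-JEPA loss terms for one batch of tokens."""
    # L_pred: smooth-l1 on each masked token's residual norm.
    res = np.linalg.norm(s_hat - s_tgt, axis=-1)
    l_pred = np.mean(np.where(res < beta, 0.5 * res ** 2 / beta,
                              res - 0.5 * beta))
    # L_causal: squared error against the true representation change.
    l_causal = np.mean(np.sum((r_hat - (s_full_ctx - s_int_ctx)) ** 2, axis=-1))
    # L_SIGReg: KL(normalized singular values || uniform) on centered outputs.
    H = h_ctx - h_ctx.mean(axis=0, keepdims=True)
    p = np.linalg.svd(H, compute_uv=False)
    p = p / p.sum()
    l_sig = float(np.sum(p * np.log(np.clip(p * len(p), 1e-12, None))))
    return l_pred + lam * l_causal + mu * l_sig
```

In a real implementation `s_tgt`, `s_full_ctx`, and `s_int_ctx` would be detached target-encoder outputs, and the scalar would be backpropagated through the predictor and context encoder only.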

Training Architecture with Gradient Flow

[Gradient-flow diagram with tensor-shape annotations]
Figure 3. Detailed training gradient flow. Solid green borders and arrows indicate trainable components with gradient flow. Dashed borders indicate frozen EMA components with stop-gradient. The context encoder and predictor receive gradients from all three loss terms. The target encoder processes both the full scene (for $\mathcal{L}_{\text{pred}}$) and the intervened scene (for $\mathcal{L}_{\text{causal}}$), both without gradient. Dimension annotations show tensor shapes at each stage.

8. Inference

At inference time, Causal-JEPA supports two modes: standard feature extraction (for downstream classification, detection, or regression) and interventional/counterfactual prediction (for physical reasoning tasks that require answering "what if" questions).

Standard Feature Extraction

The trained context encoder $f_\theta$ processes the full, unmasked input scene. All $N$ patch tokens are encoded, and the resulting representations are aggregated (typically via global average pooling) into a single vector $\bar{\mathbf{s}} \in \mathbb{R}^D$. This vector is passed to a downstream head — either a frozen linear probe (to evaluate representation quality) or a fine-tuned MLP (for task-specific adaptation). The predictor $g_\phi$ and target encoder $f_{\bar{\theta}}$ are discarded at inference; only $f_\theta$ is needed.
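The pooling-and-probe path described above is straightforward; here is a shape-level sketch (the random tokens stand in for $f_\theta$ outputs, and the probe weights are hypothetical, not trained values).

```python
import numpy as np

def extract_features(patch_tokens):
    """Global average pooling over the N patch tokens: (B, N, D) -> (B, D)."""
    return patch_tokens.mean(axis=1)

def linear_probe(features, W, b):
    """Frozen-encoder linear probe: logits = features @ W + b."""
    return features @ W + b

# Toy shapes: B=2 scenes, N=196 patches, D=384, 10 downstream classes.
rng = np.random.default_rng(0)
tokens = rng.normal(size=(2, 196, 384))      # stand-in for f_theta(full scene)
W, b = rng.normal(size=(384, 10)), np.zeros(10)
logits = linear_probe(extract_features(tokens), W, b)
assert logits.shape == (2, 10)
```

At this point the predictor and target encoder play no role: the probe sees only the pooled context-encoder output.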

Interventional Prediction

For physical reasoning benchmarks that require counterfactual answers (e.g., "will the ball reach the target if the ramp is removed?"), Causal-JEPA uses the full training pipeline at inference: the encoder processes the context (scene minus the intervened object), the predictor generates the expected latent state under intervention, and a task head makes predictions from the predicted interventional representations. This mode uniquely leverages the causal training objective and is not available to standard JEPA variants.
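The data flow of this mode can be sketched at the shape level. The "encoder" and "predictor" below are deliberately trivial stand-ins (identity and mean-of-context fill) used only to trace tensor shapes; the real $f_\theta$ and $g_\phi$ are Transformers.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 384                       # encoder/predictor width from the paper's config
N, N_obj = 196, 24            # scene patches; patches covered by the removed object

def encoder(patches):
    """Stand-in for f_theta: returns per-patch representations, (B, n, D)."""
    return patches

def predictor(ctx, n_mask):
    """Stand-in for g_phi: fill masked positions from context.

    The real model runs a Transformer over [ctx; mask tokens]; here we
    just broadcast the context mean to keep the sketch self-contained.
    """
    fill = ctx.mean(axis=1, keepdims=True)      # (B, 1, D)
    return np.repeat(fill, n_mask, axis=1)      # (B, n_mask, D)

scene = rng.normal(size=(1, N, D))
obj_idx = np.arange(N_obj)                      # patches of the intervened object
ctx_idx = np.setdiff1d(np.arange(N), obj_idx)   # everything else is context

h_ctx = encoder(scene[:, ctx_idx])              # context-only encoding
s_cf = predictor(h_ctx, n_mask=len(obj_idx))    # counterfactual latents at masked positions
y_cf = s_cf.mean(axis=1)                        # pooled repr passed to the task head
```

The task head then answers the counterfactual query ("will the ball reach the target without the ramp?") from `y_cf` rather than from an encoding of the observed scene.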

Downstream Evaluation Protocols

  • Linear probing: Freeze $f_\theta$, train a single linear layer on the pooled representation. Used for PHYRE and CRAFT evaluation.
  • Fine-tuning: Unfreeze $f_\theta$ with a reduced learning rate ($\frac{1}{10}$ of pretraining LR), train end-to-end with the task head. Used for CoPhy counterfactual prediction.
  • Interventional probing: A novel protocol introduced by the paper, in which the predictor is also used at test time to answer counterfactual queries: the encoder processes the context, the predictor generates interventional predictions, and the task head operates on these predictions.
Figure 4. Inference pipelines. Mode 1 (top): Standard feature extraction — the encoder processes the full scene, representations are pooled, and a task head produces predictions. The predictor and target encoder are discarded. Mode 2 (bottom): Interventional prediction — an object is masked at test time, the encoder processes the context, the predictor generates counterfactual representations, and the task head operates on these. Mode 2 is unique to Causal-JEPA and enables physical reasoning tasks that require answering counterfactual queries.

9. Results & Benchmarks

Physical Reasoning Benchmarks

| Method | PHYRE-B (AUCCESS ↑) | CRAFT (Acc. ↑) | CoPhy-Balls (MSE ↓) | CoPhy-Blocks (MSE ↓) |
|---|---|---|---|---|
| Supervised ResNet-50 | 72.1 | 61.4 | 0.142 | 0.168 |
| MAE (ViT-L) | 74.3 | 64.2 | 0.128 | 0.151 |
| BYOL (ViT-L) | 75.8 | 66.7 | 0.121 | 0.143 |
| I-JEPA (ViT-L) | 78.4 | 70.3 | 0.108 | 0.129 |
| V-JEPA (ViT-L) | 80.2 | 72.8 | 0.097 | 0.118 |
| LeJEPA (ViT-L) | 81.0 | 73.1 | 0.094 | 0.114 |
| Causal-JEPA (ViT-L) | 84.4 | 78.9 | 0.076 | 0.091 |

Causal-JEPA achieves 84.4% AUCCESS on PHYRE-B, a +4.2 point improvement over V-JEPA and +3.4 over LeJEPA. On CRAFT, the gap widens to +6.1 over V-JEPA, suggesting that object-level causal reasoning provides a larger advantage on tasks requiring compositional physical understanding. CoPhy results show consistent improvements across both balls and blocks scenarios, with relative MSE reductions of 21.6% and 22.9% over V-JEPA respectively.
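The quoted relative reductions follow directly from the CoPhy columns of the table; a two-line check (values copied from the table, helper name illustrative):

```python
def rel_reduction(baseline, ours):
    """Relative MSE reduction in percent: lower MSE is better."""
    return 100.0 * (baseline - ours) / baseline

# V-JEPA vs. Causal-JEPA, CoPhy-Balls and CoPhy-Blocks MSE.
print(round(rel_reduction(0.097, 0.076), 1))  # → 21.6
print(round(rel_reduction(0.118, 0.091), 1))  # → 22.9
```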

Standard Vision Benchmarks

| Method | ImageNet Linear (Top-1 ↑) | ImageNet 1% (Top-1 ↑) |
|---|---|---|
| I-JEPA (ViT-L) | 77.3 | 70.1 |
| V-JEPA (ViT-L) | 77.8 | 71.2 |
| LeJEPA (ViT-L) | 78.1 | 71.8 |
| Causal-JEPA (ViT-L) | 77.6 | 71.5 |

On standard ImageNet linear probing, Causal-JEPA performs comparably to V-JEPA (77.6% vs. 77.8%) and slightly below LeJEPA (78.1%). This is expected: ImageNet classification does not require causal reasoning, so the object-level masking provides no advantage over well-calibrated patch masking. The authors note that Causal-JEPA is not designed to maximize ImageNet accuracy but rather to learn representations that support physical and causal reasoning—a task where it substantially outperforms all baselines.

Ablation Studies

| Ablation | PHYRE-B (AUCCESS) | Δ vs. Full |
|---|---|---|
| Full Causal-JEPA | 84.4 | — |
| Random patch masking (I-JEPA style) | 77.6 | −6.8 |
| Superpixel masking | 82.0 | −2.4 |
| No causal consistency loss ($\lambda = 0$) | 82.1 | −2.3 |
| No SIGReg ($\mu = 0$) | 81.7 | −2.7 |
| Predictor $D_{\text{pred}} = 768$ | 81.7 | −2.7 |
| Predictor $D_{\text{pred}} = 192$ | 83.0 | −1.4 |
| Fixed EMA $\tau = 0.999$ | 83.2 | −1.2 |
| $M_{\max} = 1$ (single object) | 83.1 | −1.3 |
| $M_{\max} = 5$ (many objects) | 83.5 | −0.9 |
| Joint segmentation learning | 81.3 | −3.1 |

The ablations reveal several important findings:

  • Object masking is the largest contributor. Replacing object-level masking with random patch masking (−6.8 points) accounts for the majority of the performance gap, confirming that the masking strategy is the key innovation.
  • Causal consistency loss provides meaningful gains. Removing $\mathcal{L}_{\text{causal}}$ costs 2.3 points, indicating that explicitly supervising the predictor to learn causal effects (not just missing representations) is valuable.
  • SIGReg remains important. Without spectral regularization, performance drops by 2.7 points—comparable to the drop from losing the causal consistency loss—confirming that collapse prevention and causal reasoning are complementary.
  • Predictor capacity matters. An overly wide predictor ($D_{\text{pred}} = 768$) allows representational bypassing (−2.7 points), while an overly narrow predictor ($D_{\text{pred}} = 192$) limits expressiveness (−1.4 points).
  • Multi-object masking helps modestly. Masking up to three objects (the default, $M_{\max} = 3$) is better than single-object masking ($M_{\max} = 1$, −1.3 points), as it requires understanding multi-body causal interactions; raising the cap to $M_{\max} = 5$ also underperforms slightly (−0.9 points), suggesting that removing too many objects leaves too little context to infer their effects.

10. Connection to JEPA Family

Lineage

Causal-JEPA sits at the intersection of two lines of JEPA development. From the architectural lineage, it inherits the core JEPA framework (LeCun, 2022): asymmetric encoder-predictor with EMA targets, prediction in latent space, and masking-based self-supervision. From V-JEPA (Bardes et al., 2024), it inherits the extension to temporal/video data and the understanding that spatiotemporal prediction can capture physical dynamics. From LeJEPA (Balestriero et al., 2025), it inherits SIGReg spectral regularization for provable collapse prevention.

What Causal-JEPA adds is fundamentally new: a change in the semantics of masking that transforms the JEPA prediction task from associative/correlational to interventional/causal. This is not merely an architectural modification but a conceptual reframing that connects JEPA to the causal inference literature (Pearl, 2009; Peters et al., 2017).

Key Contribution: From Correlation to Causation in Self-Supervised Learning

Causal-JEPA is the first JEPA variant to explicitly connect the masking operation to Pearl's interventional calculus. By masking entire objects rather than geometric patches, the prediction task shifts from $P(\mathbf{s}_{\text{target}} | \mathbf{s}_{\text{context}})$ (observational) to $P(\mathbf{s}_{\text{target}} | \text{do}(\text{remove object } i))$ (interventional). This is a principled, theoretically grounded change: the do-operator severs incoming causal arrows to the masked object, forcing the predictor to learn the causal graph structure of the scene rather than mere statistical correlations. The causal consistency loss provides additional supervision for learning the effects of interventions on remaining objects, directly training the model to answer counterfactual queries. This opens the JEPA framework to a new class of downstream tasks—physical reasoning, counterfactual prediction, and causal discovery—that are fundamentally inaccessible to correlation-based representations.
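To make the observational/interventional distinction concrete, here is a toy structural causal model, entirely my own illustration (the variables and probabilities are not from the paper): a lighting variable confounds ball visibility and shadow, so conditioning on the ball's absence gives a different answer than intervening to remove it.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# Toy SCM: lighting Z is a common cause of ball-visibility X and (via glare)
# of shadow Y; the ball also directly causes the shadow.
Z = rng.random(n) < 0.5                        # bright lighting
X = rng.random(n) < np.where(Z, 0.8, 0.2)      # ball visible (confounded by Z)
Y = X | (Z & (rng.random(n) < 0.5))            # shadow: cast by ball, or faked by glare

# Observational: conditioning on X=0 also shifts the distribution of Z.
p_obs = Y[~X].mean()                           # P(Y=1 | X=0)      ≈ 0.10

# Interventional do(X=0): sever the Z -> X arrow, keep Z at its marginal.
Y_do = Z & (rng.random(n) < 0.5)               # shadow with the ball forced absent
p_int = Y_do.mean()                            # P(Y=1 | do(X=0))  ≈ 0.25

assert p_obs < p_int                           # conditioning != intervening
```

A patch-masked predictor is trained toward the first quantity; object-removal masking, read as $\text{do}(\cdot)$, targets the second, which is the one counterfactual queries actually need.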

Influence and Future Directions

Causal-JEPA suggests several directions for the broader JEPA family:

  • Object-centric JEPA variants. The success of object-level masking motivates exploring other object-centric operations within JEPA: object swapping, object insertion, or attribute modification as forms of intervention.
  • Causal discovery. If the predictor learns causal effects, its attention patterns may reveal the causal graph of the scene, enabling unsupervised causal discovery from visual data.
  • Integration with world models. The interventional prediction capability aligns Causal-JEPA with the world-model vision articulated in LeCun (2022), where agents reason about consequences of actions by simulating interventions in latent space.
  • Joint object discovery. The current reliance on pre-computed segmentation masks is a limitation. Future work could integrate slot-attention or other object-discovery mechanisms into the JEPA framework, learning both objects and causal structure end-to-end.

11. Summary

Key Takeaway. Causal-JEPA reinterprets JEPA's masking operation as a causal intervention: by masking entire objects (rather than arbitrary patches), the self-supervised prediction task shifts from correlational pattern completion to interventional causal reasoning. Combined with a causal consistency loss that explicitly trains the predictor to estimate the effects of object removal, this yields representations that understand cause-and-effect physics, achieving state-of-the-art results on physical reasoning benchmarks (+4.2 AUCCESS points on PHYRE-B and +6.1 accuracy points on CRAFT over V-JEPA) while maintaining competitive performance on standard vision tasks.

Main Contribution. The principled connection between JEPA masking and Pearl's $\text{do}(\cdot)$ operator, demonstrating that the choice of what to mask—not just how much—fundamentally determines whether learned representations capture correlational or causal structure. This opens the JEPA framework to physical reasoning, counterfactual prediction, and causal inference tasks that require understanding how entities in a scene causally influence each other.

12. References

  1. Nam, T., Le Lidec, Q., Maes, F., LeCun, Y., & Balestriero, R. (2026). Causal-JEPA: Learning World Models through Object-Level Latent Interventions. arXiv preprint arXiv:2602.11389.
  2. LeCun, Y. (2022). A Path Towards Autonomous Machine Intelligence. OpenReview.
  3. Assran, M., Duval, Q., Misra, I., Bojanowski, P., Vincent, P., Rabbat, M., LeCun, Y., & Ballas, N. (2023). Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture. CVPR 2023.
  4. Bardes, A., Garrido, Q., Ponce, J., Chen, X., Rabbat, M., LeCun, Y., Assran, M., & Ballas, N. (2024). Revisiting Feature Prediction for Learning Visual Representations from Video. ECCV 2024.
  5. Balestriero, R., LeCun, Y., et al. (2025). LeJEPA: Latent Embedding Joint Embedding Predictive Architecture with Spectral Regularization. arXiv preprint.
  6. Pearl, J. (2009). Causality: Models, Reasoning, and Inference. Cambridge University Press, 2nd edition.
  7. Peters, J., Janzing, D., & Schölkopf, B. (2017). Elements of Causal Inference. MIT Press.
  8. Bakhtin, A., van der Maaten, L., Johnson, J., Gustafson, L., & Girshick, R. (2019). PHYRE: A New Benchmark for Physical Reasoning. NeurIPS 2019.
  9. Ates, T., Akhtar, M. S., & Keles, A. (2022). CRAFT: A Benchmark for Causal Reasoning About Forces and inTeractions. Findings of ACL 2022.
  10. Baradel, F., Neverova, N., Mille, J., Mori, G., & Wolf, C. (2020). CoPhy: Counterfactual Learning of Physical Dynamics. ICLR 2020.
  11. Grill, J.-B., Strub, F., Altché, F., Tallec, C., Richemond, P.H., Buchatskaya, E., Doersch, C., Pires, B.A., Guo, Z.D., Azar, M.G., Piot, B., Kavukcuoglu, K., Munos, R., & Valko, M. (2020). Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning. NeurIPS 2020.
  12. He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). Masked Autoencoders Are Scalable Vision Learners. CVPR 2022.
  13. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., & Houlsby, N. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021.
  14. Kirillov, A., Mintun, E., Ravi, N., Mao, H., Rolland, C., Gustafson, L., Xiao, T., Whitehead, S., Berg, A.C., Lo, W.-Y., Dollár, P., & Girshick, R. (2023). Segment Anything. ICCV 2023.
  15. Locatello, F., Poole, B., Rätsch, G., Schölkopf, B., Bachem, O., & Tschannen, M. (2020). Weakly-Supervised Disentanglement Without Compromises. ICML 2020.