Authors: Unknown
Date: 2026-03
Category: Physics / World Models
Derives from: V-JEPA 2, LeJEPA

1. Introduction

Joint-Embedding Predictive Architectures (JEPAs) have established a powerful paradigm for self-supervised representation learning: predict missing information in a learned latent space rather than in pixel space. Models such as I-JEPA and V-JEPA demonstrated that masking-based prediction in representation space yields semantically rich features for images and short video clips. V-JEPA-2 extended this to large-scale video understanding with dense spatiotemporal prediction, while LeJEPA contributed principled spectral regularization (SIGReg) that provably prevents representation collapse. Yet a fundamental limitation persists across the entire JEPA family: these models capture local physical dynamics but cannot reason about long-horizon goals, causal relationships between distant events, or strategic task decomposition. A V-JEPA-2 world model can predict what happens physically in the next few frames, but it cannot answer why an agent should prefer one trajectory over another, or how a multi-step manipulation task should be decomposed into feasible sub-goals.

This gap is precisely what large Vision-Language Models (VLMs) excel at. Models such as GPT-4V, Gemini, and LLaVA demonstrate remarkable capacity for visual reasoning, goal specification, and step-by-step planning expressed in natural language. However, VLMs operate on discrete token sequences and lack the dense, continuous physical prediction capability needed for fine-grained motor control and physics-aware trajectory optimization. Their "understanding" of physics is linguistic and approximate, not grounded in learned dynamics.

ThinkJEPA (Zhu et al., 2026) bridges this divide by coupling a JEPA-based latent world model with a VLM reasoning module into a unified architecture for long-horizon strategic planning. The central insight is architectural complementarity: the JEPA world model provides physical intuition—dense, continuous predictions of how states evolve under actions—while the VLM provides strategic reasoning—the capacity to decompose a high-level task into ordered sub-goals, anticipate failure modes, and re-plan when predictions diverge from expectations.

ThinkJEPA's key contributions are:

  1. A dual-system architecture that formally integrates a latent world model (derived from V-JEPA-2) with a VLM reasoning module, connected via a learned Goal Projection Interface (GPI) that translates between language-specified sub-goals and latent-space target regions.
  2. Hierarchical planning: the VLM produces a chain of sub-goals in natural language; each sub-goal is projected into latent space as a target representation; the JEPA world model then performs dense rollouts to find action sequences that reach each target, yielding physically grounded multi-step plans.
  3. A reasoning-guided prediction mechanism where VLM attention maps and reasoning traces condition the world model predictor, allowing it to focus computational resources on task-relevant aspects of the scene during forward simulation.
  4. State-of-the-art results on long-horizon planning benchmarks, including CALVIN, Language-Table, and a new ThinkBench suite, demonstrating that the combination of physical prediction and language reasoning substantially outperforms either component in isolation.

Relative to its predecessors: V-JEPA-2 provides the physical prediction backbone but has no mechanism for goal specification or task decomposition; LeJEPA contributes the spectral regularization that keeps ThinkJEPA's representations well-conditioned, but does not address planning. ThinkJEPA synthesizes both and adds the entirely new VLM reasoning layer, representing a qualitative shift from perception to deliberative planning within the JEPA family.

2. Method

Consider how you plan to make a sandwich. You don't simulate every muscle contraction from start to finish. Instead, you think at a high level—"get bread, spread peanut butter, add jelly, close sandwich"—and then, for each step, you rely on physical intuition to execute the motion. If the jelly jar is stuck, you don't re-plan the entire sandwich; you adapt locally. ThinkJEPA formalizes exactly this dual-process reasoning.

Intuition: Two Systems of Thought. ThinkJEPA implements a computational analogue of dual-process theory from cognitive science. System 2 (the VLM) performs slow, deliberate, language-mediated reasoning to decompose a task into sub-goals. System 1 (the JEPA world model) performs fast, intuitive, continuous physical prediction to find motor plans that achieve each sub-goal. Neither system alone is sufficient: System 2 without System 1 produces plans that are physically implausible; System 1 without System 2 produces trajectories that are locally optimal but globally aimless.

The method proceeds in three stages:

Stage 1: World Model Pre-training. A V-JEPA-2-style encoder and predictor are trained on large-scale video data to learn dense physical prediction in latent space. The encoder maps video frames to patch-level representations; the predictor, conditioned on actions (when available) and context representations, forecasts future latent states. LeJEPA's SIGReg regularization is applied to the representation covariance to prevent dimensional collapse and ensure the latent space is well-distributed. This stage produces a world model that "understands" physics but has no notion of goals.

Stage 2: Goal Projection Interface Training. The Goal Projection Interface (GPI) is trained to translate between natural language sub-goal descriptions and regions in the world model's latent space. Given a language description like "the red block is on top of the blue block" and a set of video frames depicting that state, the GPI learns a mapping from the VLM's language embedding to a target representation in JEPA latent space. This is trained contrastively: matching language-state pairs should have high cosine similarity in a shared projection space, while non-matching pairs should be dissimilar.

Intuition: The Goal Projection Interface as a Translator. Imagine a project manager (the VLM) who speaks only English, directing an engineer (the JEPA world model) who works only with blueprints. The GPI is the bilingual translator who converts the manager's verbal instructions ("make the bridge support 10 tons") into precise blueprint specifications (target stress distributions in the structural model). Without this translator, neither party can effectively collaborate.

Stage 3: Joint Reasoning and Planning. At deployment, a task is specified in natural language (e.g., "stack the blocks in order of size"). The VLM reasons about this task—potentially through chain-of-thought—and produces an ordered list of sub-goals. Each sub-goal is projected through the GPI into a latent target. The JEPA world model then performs forward rollouts (using model-predictive control) to find action sequences that move the predicted latent state toward each target in sequence. If rollout predictions diverge significantly from expected sub-goal achievement, the VLM is re-queried for plan revision.

The crucial architectural decision is that the VLM and world model share no weights. The VLM is a pre-trained, frozen (or lightly fine-tuned) vision-language model; the JEPA world model is independently pre-trained. They communicate only through the GPI bottleneck. This separation preserves the strengths of each component and allows independent scaling.

3. Model Overview

At-a-Glance

| Property | ThinkJEPA |
| --- | --- |
| Input | Video frames (RGB) + language task specification |
| Masking | Spatiotemporal block masking during world model pre-training (V-JEPA-2 style); N/A during planning |
| Vision Encoder | ViT-H/16 (V-JEPA-2 backbone), ~632M params |
| VLM Reasoning Module | 7B-parameter VLM (architecture similar to LLaVA-NeXT or Qwen-VL) |
| Predictor | Transformer predictor with action conditioning, ~300M params |
| Goal Projection Interface | Cross-modal projection network, ~85M params |
| Loss (pre-training) | Smooth-L1 prediction loss + SIGReg spectral regularization |
| Loss (GPI training) | InfoNCE contrastive loss over language-state pairs |
| Key Result | 73.2% task success on CALVIN ABC→D (vs. 58.4% for V-JEPA-2 planning baseline) |
| Total Params | ~8.0B (including frozen VLM); ~1.0B trainable during world model + GPI stages |
Figure 1: ThinkJEPA training architecture. The world model (left) is trained with Smooth-L1 + SIGReg prediction loss via V-JEPA-2-style spatiotemporal masking. The VLM reasoning module (right) is frozen or lightly adapted with LoRA. The Goal Projection Interface (GPI) is trained contrastively to align language sub-goals with latent-space target representations. Gradients flow only through solid-bordered modules.

4. Main Components of ThinkJEPA

4.1 Vision Encoder (Online Encoder)

WHAT: The online encoder $f_\theta$ is a Vision Transformer (ViT-H/16) that maps each video frame into a set of patch-level latent representations. For a frame $x_t \in \mathbb{R}^{3 \times H \times W}$, the encoder produces $s_t = f_\theta(x_t) \in \mathbb{R}^{N \times D}$ where $N$ is the number of spatial patches and $D$ is the embedding dimension.

HOW: The encoder uses patch size $16 \times 16$, yielding $N = (224/16)^2 = 196$ tokens per frame. Embedding dimension $D = 1280$ (ViT-Huge). The encoder consists of 32 transformer blocks with 16 attention heads. During world model pre-training, the encoder processes only the context tokens (unmasked patches from the current frame and possibly past frames), following the asymmetric masking paradigm of V-JEPA-2. Learnable spatiotemporal positional embeddings are added to each token.

WHY: The ViT-H/16 architecture is chosen for its capacity to learn rich spatial representations at scale. Ablations in the paper compare ViT-L/16 (~307M) and ViT-H/16 (~632M) encoders: the larger encoder improves CALVIN task success by 4.7 percentage points (68.5% → 73.2%), confirming that world model representation quality directly impacts downstream planning performance. The patch-level representation (rather than a single [CLS] token) is essential because the predictor and GPI require spatially resolved features to ground sub-goals to specific scene locations.

4.2 Target Encoder (EMA)

WHAT: The target encoder $\bar{f}_\xi$ is an exponential moving average (EMA) copy of the online encoder. It processes the full (unmasked) video frames to produce target representations $s^*_t = \bar{f}_\xi(x_t)$ used as prediction targets during world model training.

HOW: The EMA update rule is $\xi \leftarrow \tau \xi + (1 - \tau)\theta$, where the momentum coefficient $\tau$ follows a cosine schedule from $\tau_0 = 0.996$ to $\tau_1 = 1.0$ over the course of training. No gradients flow through the target encoder; it is updated only via the EMA mechanism. Target representations are normalized per-patch (L2 normalization) before being used as prediction targets.
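The EMA update and its cosine momentum schedule can be sketched in a few lines (a minimal NumPy sketch; the function names are illustrative, not from the paper):

```python
import math
import numpy as np

def tau_cosine(step, total_steps, tau0=0.996, tau1=1.0):
    """Cosine schedule for the EMA momentum: tau0 at step 0, tau1 at the end."""
    progress = step / total_steps
    return tau1 - (tau1 - tau0) * 0.5 * (1.0 + math.cos(math.pi * progress))

def ema_update(target_params, online_params, tau):
    """xi <- tau * xi + (1 - tau) * theta, applied parameter-wise."""
    return [tau * xi + (1.0 - tau) * theta
            for xi, theta in zip(target_params, online_params)]
```

Note that no gradient machinery is involved: the target encoder is updated purely by this running average.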

WHY: The EMA target encoder provides a slowly evolving, stable set of prediction targets. This asymmetry between the rapidly updated online encoder and the slowly updated target encoder, combined with the predictor bottleneck, prevents representational collapse—the degenerate solution where both encoders learn to output a constant regardless of input. The cosine schedule for $\tau$ is critical: early training benefits from faster target updates ($\tau = 0.996$) that allow the target to track improving representations, while late training benefits from near-frozen targets ($\tau \approx 1.0$) that provide stable objectives. Ablations show that a fixed $\tau = 0.999$ throughout training reduces CALVIN success by 2.1 points compared to the cosine schedule.

4.3 World Model Predictor

WHAT: The predictor $g_\phi$ is a transformer network that, given context representations from the current state and an action, predicts the target representations of the next state. Formally: $\hat{s}_{t+1} = g_\phi(s_t^{\text{ctx}}, a_t, \mathbf{m})$ where $s_t^{\text{ctx}}$ are the context (unmasked) token representations, $a_t$ is the action vector, and $\mathbf{m}$ denotes learnable mask tokens placed at target positions.

HOW: The predictor is a 12-layer transformer with hidden dimension $D_p = 768$, 12 attention heads, and a linear projection from $D_p$ back to $D = 1280$ at the output. Action conditioning is implemented via FiLM (Feature-wise Linear Modulation): the action $a_t \in \mathbb{R}^{D_a}$ is projected through a 2-layer MLP to produce per-layer scale $\gamma_l$ and shift $\beta_l$ parameters that modulate the predictor's hidden representations after each layer normalization. Learnable mask tokens $\mathbf{m} \in \mathbb{R}^{D_p}$ are appended at target positions, and the predictor outputs predictions only at these positions. The narrow bottleneck ($D_p < D$) is intentional.
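The FiLM conditioning step can be sketched as follows (a minimal NumPy sketch applied at a single layer for illustration; the MLP weight shapes and names are placeholders, not the paper's implementation):

```python
import numpy as np

def film_condition(h, action, W1, b1, W2, b2):
    """FiLM action conditioning: a 2-layer MLP maps the action vector to
    per-layer scale (gamma) and shift (beta) that modulate hidden tokens h.
    h: (N, D_p) predictor hidden states; action: (D_a,)."""
    z = np.maximum(W1 @ action + b1, 0.0)   # hidden layer of the MLP (ReLU)
    gs = W2 @ z + b2                        # (2 * D_p,): concatenated gamma, beta
    gamma, beta = np.split(gs, 2)
    # Applied after each layer normalization in the real predictor.
    return h * (1.0 + gamma) + beta
```

With zero MLP weights this reduces to the identity, which is a common initialization so that conditioning starts as a no-op.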

WHY: The predictor bottleneck ($D_p = 768$ vs. encoder $D = 1280$) ensures that the predictor cannot simply copy the encoder's representations through an identity mapping, forcing the encoder to learn representations that are inherently predictable—i.e., that capture the regularities of physical dynamics. FiLM conditioning for actions is chosen over concatenation or cross-attention because it modulates the predictor's internal computation at every layer rather than providing action information only at the input, yielding better multi-step rollout accuracy. Ablations show FiLM improves 5-step rollout cosine similarity by 0.08 over concatenation-based conditioning.

4.4 VLM Reasoning Module

WHAT: The VLM reasoning module $R_\psi$ is a 7-billion-parameter Vision-Language Model that takes as input the current visual observation and a natural language task description, and outputs: (1) a chain-of-thought reasoning trace, (2) an ordered sequence of sub-goals in natural language, and (3) visual attention maps indicating which scene regions are relevant to each sub-goal.

HOW: The VLM uses a CLIP-based vision encoder to process the current frame into visual tokens, which are interleaved with language tokens and processed by a 32-layer, 4096-dimensional language model backbone. During ThinkJEPA training, the VLM is either fully frozen or adapted via LoRA (rank 16, $\alpha = 32$) on task-specific planning datasets. The VLM is prompted with a structured template:

prompt = f"""You are a robot planning assistant. 
Current observation: [IMAGE]
Task: {task_description}
Decompose this task into ordered sub-goals. For each sub-goal, 
describe the target scene state and which objects are involved.
Output format: 
1. [sub-goal description] | [relevant objects]
2. [sub-goal description] | [relevant objects]
..."""
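The structured output of this template can be parsed with a small helper (a hypothetical sketch; the paper does not specify its parsing code):

```python
import re

def parse_subgoals(text):
    """Parse lines of the form 'k. <description> | <objects>' produced by
    the prompt template into (description, [objects]) tuples."""
    subgoals = []
    for line in text.splitlines():
        m = re.match(r"\s*\d+\.\s*(.+?)\s*\|\s*(.+)", line)
        if m:
            desc, objs = m.group(1), m.group(2)
            subgoals.append((desc, [o.strip() for o in objs.split(",")]))
    return subgoals
```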

The VLM's visual attention maps from the last transformer layer are extracted as spatial relevance masks $\mathbf{w}_k \in \mathbb{R}^{N}$ for each sub-goal $k$, indicating which patches are most relevant.

WHY: A 7B VLM strikes a balance between reasoning capability and computational cost. Larger VLMs (e.g., 70B) improve sub-goal quality marginally (+1.3% task success on CALVIN) but increase inference latency by 8×, making real-time planning infeasible. The attention-map extraction provides a crucial bridge: rather than the VLM communicating with the world model only through discrete language tokens, the spatial attention maps provide continuous guidance about where in the scene each sub-goal is grounded. LoRA fine-tuning on robot planning data improves sub-goal decomposition quality (measured by human agreement) from 62% to 81% while updating only 0.3% of VLM parameters.

4.5 Goal Projection Interface (GPI)

WHAT: The GPI is a cross-modal projection network that maps VLM language embeddings and spatial attention maps into the JEPA world model's latent space, producing latent goal representations $g_k \in \mathbb{R}^{N \times D}$ for each sub-goal $k$.

HOW: The GPI has three sub-components: (1) a language projector $\pi_\text{lang}: \mathbb{R}^{D_\text{VLM}} \rightarrow \mathbb{R}^{D}$ that maps the VLM's sub-goal embedding to the world model's representation dimension; (2) a spatial modulation layer that combines the projected language embedding with the VLM's spatial attention map $\mathbf{w}_k$ to distribute the goal signal across patches: $g_k^{(n)} = \pi_\text{lang}(e_k) \cdot w_k^{(n)}$ where $e_k$ is the VLM embedding for sub-goal $k$ and $w_k^{(n)}$ is the attention weight for patch $n$; (3) a 4-layer transformer refinement module that takes the spatially distributed goal tokens and refines them through self-attention, producing the final latent goal $g_k$. Total parameters: ~85M.
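The three sub-components can be sketched as follows (a minimal NumPy sketch; the `refine` stand-in replaces the 4-layer transformer refinement module, and all names are illustrative):

```python
import numpy as np

def gpi_forward(e_k, w_k, W_lang, refine=lambda g: g):
    """Goal Projection Interface sketch:
    (1) project the VLM sub-goal embedding into the world-model dimension,
    (2) spatially modulate by the VLM attention map w_k,
    (3) refine (stand-in for the transformer refinement module).
    e_k: (D_VLM,), w_k: (N,), W_lang: (D, D_VLM). Returns g_k: (N, D)."""
    lang = W_lang @ e_k                 # projected language embedding, (D,)
    g = np.outer(w_k, lang)             # g_k^(n) = w_k^(n) * pi_lang(e_k)
    return refine(g)
```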

WHY: The GPI must bridge two fundamentally different representation spaces: the VLM's discrete, semantic token space and the world model's continuous, physics-grounded latent space. A simple linear projection is insufficient (ablation: −8.3% task success). The spatial modulation step is critical because sub-goals are often localized ("pick up the red block") and the GPI must ground this to specific patch positions. The transformer refinement module allows cross-patch interaction so that relational sub-goals ("place A next to B") can be properly represented. Ablation of the refinement module costs 5.1% task success.

4.6 Masking Strategy (World Model Pre-training)

WHAT: During world model pre-training, ThinkJEPA uses spatiotemporal block masking following V-JEPA-2. Large contiguous regions of the spatiotemporal volume are masked from the online encoder (context) and must be predicted by the world model predictor, with the target encoder providing supervision.

HOW: For each training clip of $T = 16$ frames, 4 target blocks are sampled. Each target block spans a random spatial region of aspect ratio uniformly drawn from $[0.75, 1.5]$ and scale uniformly drawn from $[0.15, 0.7]$ of the total spatial area, and a temporal extent of 4–8 frames. The context consists of all patches not covered by any target block. The context masking ratio ranges from 65% to 90% across batches (the fraction of tokens masked).
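Sampling one target block under these ranges might look like the following (a sketch assuming the 14×14 patch grid implied by 224/16; the function name is hypothetical):

```python
import numpy as np

def sample_target_block(rng, grid=14, T=16):
    """Sample one spatiotemporal target block per the stated ranges:
    spatial scale U[0.15, 0.7] of frame area, aspect ratio U[0.75, 1.5],
    temporal extent 4-8 frames. Returns a boolean (T, grid, grid) mask."""
    scale = rng.uniform(0.15, 0.7)
    ratio = rng.uniform(0.75, 1.5)
    area = scale * grid * grid
    h = int(round(np.sqrt(area / ratio)))
    w = int(round(np.sqrt(area * ratio)))
    h, w = min(max(h, 1), grid), min(max(w, 1), grid)
    t_len = int(rng.integers(4, 9))               # 4..8 frames inclusive
    t0 = int(rng.integers(0, T - t_len + 1))
    r0 = int(rng.integers(0, grid - h + 1))
    c0 = int(rng.integers(0, grid - w + 1))
    mask = np.zeros((T, grid, grid), dtype=bool)
    mask[t0:t0 + t_len, r0:r0 + h, c0:c0 + w] = True
    return mask
```

In full training this would be repeated 4 times per clip and the union taken as the target set, with the complement serving as context.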

Figure 2: Spatiotemporal block masking during world model pre-training. Four target blocks (T1–T4) of varying spatial extent and temporal span are masked from the online encoder. The predictor must predict the target encoder's representations at these positions. This masking follows V-JEPA-2 conventions.

4.7 Loss Functions

World Model Prediction Loss

The world model is trained with a Smooth-L1 (Huber) loss between predicted and target representations at masked positions:

$$\mathcal{L}_{\text{pred}} = \frac{1}{|\mathcal{M}|} \sum_{(t,n) \in \mathcal{M}} \text{SmoothL1}\!\left(\hat{s}_{t}^{(n)},\; s^{*\,(n)}_{t}\right)$$

where $\mathcal{M}$ is the set of masked (target) spatiotemporal positions, $\hat{s}_{t}^{(n)} \in \mathbb{R}^D$ is the predictor's output for patch $n$ at time $t$, and $s^{*\,(n)}_{t} \in \mathbb{R}^D$ is the corresponding L2-normalized target encoder output. The Smooth-L1 loss is defined element-wise as:

$$\text{SmoothL1}(\hat{s}, s^*) = \frac{1}{D}\sum_{d=1}^{D} \begin{cases} \frac{1}{2}(\hat{s}_d - s^*_d)^2 / \beta & \text{if } |\hat{s}_d - s^*_d| < \beta \\ |\hat{s}_d - s^*_d| - \frac{\beta}{2} & \text{otherwise} \end{cases}$$

with $\beta = 1.0$. The Smooth-L1 loss is preferred over MSE because it is less sensitive to outlier predictions, which are common early in training when the predictor makes large errors on some masked regions.
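The piecewise definition above translates directly to code (a minimal NumPy sketch, averaging over the feature dimension as in the formula):

```python
import numpy as np

def smooth_l1(pred, target, beta=1.0):
    """Element-wise Smooth-L1 (Huber) loss, mean over the last (feature) axis.
    Quadratic for |diff| < beta, linear beyond, matching the piecewise form."""
    diff = np.abs(pred - target)
    loss = np.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    return loss.mean(axis=-1)
```

For example, an error of 0.5 falls in the quadratic branch (loss 0.125), while an error of 2.0 falls in the linear branch (loss 1.5), which is what makes the loss robust to large early-training errors.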

SIGReg Spectral Regularization (from LeJEPA)

To prevent representational collapse, ThinkJEPA applies SIGReg regularization to the online encoder's output representations. Let $S \in \mathbb{R}^{B' \times D}$ be the matrix of encoder output representations aggregated over a batch (with $B' = B \times N$ where $B$ is batch size and $N$ is the number of context tokens). Let $\bar{S}$ denote the column-centered version. Compute the singular values $\sigma_1 \geq \sigma_2 \geq \cdots \geq \sigma_{\min(B', D)}$ of $\bar{S}$. Define the normalized singular value distribution:

$$p_i = \frac{\sigma_i^2}{\sum_j \sigma_j^2}$$

The SIGReg loss is the negative entropy of this distribution; minimizing it therefore maximizes entropy, encouraging all dimensions to carry equal variance:

$$\mathcal{L}_{\text{SIGReg}} = -H(p) = \sum_i p_i \log p_i$$

Minimizing $\mathcal{L}_{\text{SIGReg}}$ pushes the singular value distribution toward uniform, preventing any dimension from dominating (collapse to a low-rank subspace) or vanishing.
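The computation can be sketched end-to-end (a minimal NumPy sketch; it assumes the centered matrix has nonzero total variance):

```python
import numpy as np

def sigreg_loss(S):
    """SIGReg sketch: negative entropy of the normalized squared singular
    values of the column-centered representation matrix S of shape (B', D)."""
    S_centered = S - S.mean(axis=0, keepdims=True)
    sv = np.linalg.svd(S_centered, compute_uv=False)
    p = sv ** 2 / np.sum(sv ** 2)         # normalized spectrum, sums to 1
    p = p[p > 0]                          # guard against log(0)
    return float(np.sum(p * np.log(p)))   # -H(p); minimal when p is uniform
```

A collapsed (rank-1) batch gives a loss near 0, while a well-spread isotropic batch gives a loss near $-\log D$, so gradient descent on this term pushes representations away from collapse.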

GPI Contrastive Loss

The Goal Projection Interface is trained with an InfoNCE contrastive loss. Given a batch of $K$ (language sub-goal, latent state) pairs $\{(e_k, s^*_k)\}_{k=1}^K$, where $e_k$ is the VLM's embedding for sub-goal $k$ and $s^*_k$ is the target encoder's representation of a video frame depicting that sub-goal's achievement:

$$\mathcal{L}_{\text{GPI}} = -\frac{1}{K} \sum_{k=1}^{K} \log \frac{\exp\!\left(\text{sim}(\text{GPI}(e_k, \mathbf{w}_k),\; \bar{s}^*_k) / \tau_c\right)}{\sum_{j=1}^{K} \exp\!\left(\text{sim}(\text{GPI}(e_k, \mathbf{w}_k),\; \bar{s}^*_j) / \tau_c\right)}$$

where $\text{sim}(\cdot, \cdot)$ is cosine similarity computed over mean-pooled representations, $\bar{s}^*_k = \frac{1}{N}\sum_n s^{*\,(n)}_k$ is the spatial-mean of the target representation, $\mathbf{w}_k$ is the VLM's spatial attention map for sub-goal $k$, and $\tau_c = 0.07$ is the temperature. The symmetric version (also matching states to language) is used in practice.
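The symmetric InfoNCE objective over mean-pooled pairs can be sketched as (a minimal NumPy sketch; in practice the goal side would be the GPI output and the state side the pooled target representation):

```python
import numpy as np

def info_nce(goals, states, tau_c=0.07):
    """Symmetric InfoNCE over K (projected goal, mean-pooled state) pairs.
    goals, states: (K, D); row k of each side is the matching positive."""
    g = goals / np.linalg.norm(goals, axis=1, keepdims=True)
    s = states / np.linalg.norm(states, axis=1, keepdims=True)
    logits = g @ s.T / tau_c                               # (K, K) cosine / temp
    lp_g2s = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    lp_s2g = logits.T - np.log(np.exp(logits.T).sum(axis=1, keepdims=True))
    # Diagonal entries are the matching pairs; average both directions.
    return 0.5 * (-np.mean(np.diag(lp_g2s)) - np.mean(np.diag(lp_s2g)))
```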

Total Training Loss

The total loss during joint training (Stage 2) is:

$$\mathcal{L} = \mathcal{L}_{\text{pred}} + \lambda_{\text{reg}} \mathcal{L}_{\text{SIGReg}} + \lambda_{\text{GPI}} \mathcal{L}_{\text{GPI}}$$

with $\lambda_{\text{reg}} = 0.5$ and $\lambda_{\text{GPI}} = 1.0$. During Stage 1 (world model pre-training only), $\lambda_{\text{GPI}} = 0$.

4.8 Reasoning-Guided Prediction (Variant-Specific Component)

WHAT: A novel mechanism unique to ThinkJEPA where the VLM's spatial attention maps condition the world model predictor during forward rollouts at planning time, focusing computational resources on task-relevant scene regions.

HOW: During planning rollouts, the VLM's attention map $\mathbf{w}_k$ for the current sub-goal $k$ is used to re-weight the predictor's cross-attention scores. Specifically, in each cross-attention layer $l$ of the predictor, the attention logits $A^l \in \mathbb{R}^{N_{\text{tgt}} \times N_{\text{ctx}}}$ are additively biased:

$$\tilde{A}^l_{i,j} = A^l_{i,j} + \alpha \cdot \log(w_k^{(j)} + \epsilon)$$

where $\alpha = 0.3$ is a scaling factor and $\epsilon = 10^{-6}$ prevents numerical issues. This soft bias causes the predictor to attend more strongly to context patches that the VLM deems relevant to the current sub-goal, improving prediction accuracy for task-relevant dynamics.
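The bias is a one-line operation on the attention logits before the softmax (a minimal NumPy sketch; the helper names are illustrative):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def reasoning_guided_bias(attn_logits, w_k, alpha=0.3, eps=1e-6):
    """Additively bias cross-attention logits toward context patches the VLM
    deems relevant: A~[i, j] = A[i, j] + alpha * log(w_k[j] + eps).
    attn_logits: (N_tgt, N_ctx); w_k: (N_ctx,)."""
    return attn_logits + alpha * np.log(w_k + eps)[None, :]
```

Because the bias is added in log space, patches with near-zero attention weight are strongly suppressed while the re-weighting stays a valid (soft) attention distribution after the softmax.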

WHY: Without reasoning-guided prediction, the world model predictor allocates attention uniformly across all context patches. For long-horizon tasks, this means the predictor wastes capacity modeling background dynamics irrelevant to the current sub-goal. Ablation shows reasoning-guided prediction improves 10-step rollout accuracy (cosine similarity to ground-truth target representations) by 0.12 and task success rate by 3.8% on CALVIN.

5. Implementation Details

| Hyperparameter | World Model (Stage 1) | GPI Training (Stage 2) |
| --- | --- | --- |
| Encoder architecture | ViT-H/16, 32 layers, 16 heads, D=1280 | (frozen from Stage 1) |
| Predictor architecture | 12 layers, 12 heads, D_p=768 | (frozen from Stage 1) |
| VLM architecture | N/A | 7B, 32 layers, D=4096, LoRA r=16 |
| GPI architecture | N/A | 4-layer transformer, D=1280, 16 heads |
| Patch size | 16×16 | 16×16 |
| Input resolution | 224×224 | 224×224 |
| Clip length | T=16 frames | T=16 frames |
| Optimizer | AdamW (β₁=0.9, β₂=0.95) | AdamW (β₁=0.9, β₂=0.95) |
| Learning rate | 1.5×10⁻⁴ (encoder), 3×10⁻⁴ (predictor) | 5×10⁻⁵ (GPI), 2×10⁻⁵ (LoRA) |
| LR schedule | Cosine decay with 10-epoch linear warmup | Cosine decay with 2-epoch warmup |
| Weight decay | 0.05 | 0.01 |
| Batch size | 256 clips (across GPUs) | 128 (language-state pairs) |
| Epochs | 300 (on video pre-training data) | 50 (on robot planning datasets) |
| GPUs | 64× A100 80GB | 16× A100 80GB |
| EMA schedule | τ: 0.996 → 1.0 (cosine) | N/A |
| SIGReg weight (λ_reg) | 0.5 | 0.5 |
| GPI loss weight (λ_GPI) | 0.0 | 1.0 |
| Contrastive temperature (τ_c) | N/A | 0.07 |
| Mixed precision | BF16 | BF16 |
| Pre-training data | Something-Something v2 + Ego4D + RoboSet | CALVIN, Language-Table |

No public repository has been released for ThinkJEPA as of this writing.

6. Algorithm

Algorithm 1: ThinkJEPA — World Model Pre-training (Stage 1)
Input: Video dataset $\mathcal{D}_v$; online encoder $f_\theta$; target encoder $\bar{f}_\xi$; predictor $g_\phi$; masking config $\mathcal{C}$; EMA schedule $\tau(\cdot)$
Output: Trained world model parameters $\theta, \phi$
1 Initialize $\xi \leftarrow \theta$ // target encoder starts as copy of online encoder
2 for epoch $= 1, \ldots, E_1$ do
3 for each mini-batch $\{(x^{(i)}_1, \ldots, x^{(i)}_T, a^{(i)}_1, \ldots, a^{(i)}_{T-1})\}_{i=1}^B$ from $\mathcal{D}_v$ do
4 Sample 4 target blocks $\mathcal{M} \subset \{1,\ldots,T\} \times \{1,\ldots,N\}$ per config $\mathcal{C}$
5 Compute context set: $\mathcal{M}^c = \{(t,n) : (t,n) \notin \mathcal{M}\}$
6 // Online encoder: context tokens only
7 $s^{\text{ctx}} \leftarrow f_\theta(\{x_t^{(n)}: (t,n) \in \mathcal{M}^c\})$ // B×|M^c|×D
8 // Target encoder: full frames, no gradient
9 with no_grad():
10 $s^* \leftarrow \bar{f}_\xi(\{x_t\}_{t=1}^T)$ // B×T×N×D
11 $s^* \leftarrow \text{L2Normalize}(s^*, \text{dim}=-1)$
12 // Predictor: predict targets from context + actions
13 Initialize mask tokens $\mathbf{m}$ at positions $\mathcal{M}$
14 $\hat{s} \leftarrow g_\phi(s^{\text{ctx}}, \{a_t\}, \mathbf{m})$ // B×|M|×D, predictions at target positions
15 // Compute losses
16 $\mathcal{L}_{\text{pred}} \leftarrow \frac{1}{|\mathcal{M}|} \sum_{(t,n) \in \mathcal{M}} \text{SmoothL1}(\hat{s}_{t}^{(n)}, s^{*\,(n)}_{t})$
17 Compute $\mathcal{L}_{\text{SIGReg}}$ from singular values of centered $s^{\text{ctx}}$
18 $\mathcal{L} \leftarrow \mathcal{L}_{\text{pred}} + \lambda_{\text{reg}} \mathcal{L}_{\text{SIGReg}}$
19 Update $\theta, \phi$ via AdamW on $\nabla_{\theta,\phi} \mathcal{L}$
20 // EMA update of target encoder
21 $\tau_t \leftarrow \tau(\text{current\_step})$ // cosine from 0.996 → 1.0
22 $\xi \leftarrow \tau_t \xi + (1 - \tau_t) \theta$
23 end for
24 end for
25 return $\theta, \phi$
Algorithm 2: ThinkJEPA — Hierarchical Planning with VLM Reasoning (Inference)
Input: Current observation $x_0$; task description $\ell$; trained encoder $f_\theta$; predictor $g_\phi$; VLM $R_\psi$; GPI; planning horizon $H$; CEM parameters (population $P$, elite fraction $\rho$, iterations $I$); goal threshold $\delta_{\text{goal}}$; per-sub-goal step budget $T_{\max}$
Output: Action sequence $\{a_0, a_1, \ldots\}$
1 // Step 1: VLM task decomposition
2 $\{(\ell_k, \mathbf{w}_k)\}_{k=1}^{K} \leftarrow R_\psi(x_0, \ell)$ // K sub-goals with attention maps
3 // Step 2: Project sub-goals to latent space
4 for $k = 1, \ldots, K$ do
5 $e_k \leftarrow \text{VLM\_embed}(\ell_k)$ // language embedding of sub-goal k
6 $g_k \leftarrow \text{GPI}(e_k, \mathbf{w}_k)$ // latent goal representation, N×D
7 end for
8 // Step 3: Sequential sub-goal execution via MPC
9 $s_{\text{curr}} \leftarrow f_\theta(x_0)$ // encode current state
10 for $k = 1, \ldots, K$ do
11 while $\text{sim}(\overline{s}_{\text{curr}}, \overline{g}_k) < \delta_{\text{goal}}$ do // until sub-goal reached
12 // CEM-based action optimization
13 Initialize action distribution: $\mu \leftarrow \mathbf{0}_{H \times D_a}$, $\Sigma \leftarrow \mathbf{I}_{H \times D_a}$
14 for $i = 1, \ldots, I$ do // CEM iterations
15 Sample $P$ action sequences: $\{A^{(p)}\}_{p=1}^P \sim \mathcal{N}(\mu, \Sigma)$, $A^{(p)} \in \mathbb{R}^{H \times D_a}$
16 for each $A^{(p)} = (a_0^{(p)}, \ldots, a_{H-1}^{(p)})$ do
17 $\hat{s}_0 \leftarrow s_{\text{curr}}$
18 for $h = 0, \ldots, H-1$ do
19 Apply reasoning-guided attention bias using $\mathbf{w}_k$
20 $\hat{s}_{h+1} \leftarrow g_\phi(\hat{s}_h, a_h^{(p)})$ // world model rollout
21 end for
22 $J^{(p)} \leftarrow \text{sim}(\overline{\hat{s}}_H, \overline{g}_k)$ // reward: similarity to goal
23 end for
24 Select elite set $\mathcal{E}$: top-$\lceil \rho P \rceil$ sequences by $J^{(p)}$
25 $\mu \leftarrow \text{mean}(\{A^{(p)}: p \in \mathcal{E}\})$, $\Sigma \leftarrow \text{var}(\{A^{(p)}: p \in \mathcal{E}\})$
26 end for
27 Execute first action: $a^* \leftarrow \mu[0]$
28 Observe new state $x_{\text{new}}$; $s_{\text{curr}} \leftarrow f_\theta(x_{\text{new}})$
29 end while
30 // Optional: re-query VLM if sub-goal unreachable after max steps
31 if steps exceed $T_{\max}$ then $\{(\ell_k, \mathbf{w}_k)\}_{k=1}^{K} \leftarrow R_\psi(x_{\text{new}}, \ell, \text{history})$ // re-plan
32 end for
33 return executed action sequence
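The CEM inner loop of Algorithm 2 (lines 13–26) can be sketched as follows (a minimal NumPy sketch; `rollout_reward` stands in for the world-model rollout plus goal-similarity scoring, and the parameter names are illustrative):

```python
import numpy as np

def cem_plan(rollout_reward, H, D_a, P=64, rho=0.1, iters=5, seed=0):
    """Cross-Entropy Method over action sequences: sample P sequences from a
    Gaussian, keep the top-rho fraction by reward, refit mean/variance.
    rollout_reward maps an (H, D_a) action sequence to a scalar score."""
    rng = np.random.default_rng(seed)
    mu, sigma = np.zeros((H, D_a)), np.ones((H, D_a))
    for _ in range(iters):
        pop = rng.normal(mu, sigma, size=(P, H, D_a))
        scores = np.array([rollout_reward(A) for A in pop])
        n_elite = max(int(np.ceil(rho * P)), 2)
        elite = pop[np.argsort(scores)[-n_elite:]]       # highest-reward sequences
        mu = elite.mean(axis=0)
        sigma = elite.std(axis=0) + 1e-6                 # floor to avoid collapse
    return mu[0]  # MPC: execute only the first action of the refined plan
```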

7. Training

Step-by-Step: One Training Iteration

Stage 1 (World Model Pre-training) — Single Iteration:

  1. Sample a mini-batch of $B = 256$ video clips, each containing $T = 16$ frames at 224×224 resolution and corresponding actions. Each clip is a tensor of shape $B \times T \times 3 \times 224 \times 224$.
  2. Generate masks. For each clip, sample 4 target blocks as described in Section 4.6. Compute the context set $\mathcal{M}^c$ (all non-target positions). Typically 65–90% of the spatiotemporal volume is masked.
  3. Target encoder forward pass (no gradient). Pass all $T$ full frames through the EMA target encoder $\bar{f}_\xi$ to produce $s^* \in \mathbb{R}^{B \times T \times 196 \times 1280}$. Apply L2 normalization along the last dimension. This entire computation is performed under torch.no_grad().
  4. Online encoder forward pass. Extract context tokens (non-masked positions) from each frame, add positional embeddings, and pass through the online encoder $f_\theta$. Output: $s^{\text{ctx}} \in \mathbb{R}^{B \times |\mathcal{M}^c| \times 1280}$.
  5. Predictor forward pass. Initialize learnable mask tokens at target positions. Concatenate context representations and mask tokens (with position embeddings). Apply FiLM action conditioning at each predictor layer. The predictor attends over the full set and outputs predictions $\hat{s} \in \mathbb{R}^{B \times |\mathcal{M}| \times 1280}$ at target positions only.
  6. Compute prediction loss. $\mathcal{L}_{\text{pred}} = \text{mean}(\text{SmoothL1}(\hat{s}, s^*_{\mathcal{M}}))$ where $s^*_{\mathcal{M}}$ denotes the target representations extracted at masked positions.
  7. Compute SIGReg regularization. Center the online encoder outputs, compute SVD, obtain normalized singular value distribution, compute negative entropy $\mathcal{L}_{\text{SIGReg}}$.
  8. Backward pass. $\mathcal{L} = \mathcal{L}_{\text{pred}} + 0.5 \cdot \mathcal{L}_{\text{SIGReg}}$. Compute gradients w.r.t. $\theta$ (encoder) and $\phi$ (predictor). No gradients through the target encoder.
  9. Optimizer step. Update $\theta, \phi$ via AdamW with separate learning rates (1.5×10⁻⁴ for encoder, 3×10⁻⁴ for predictor).
  10. EMA update. $\xi \leftarrow \tau_t \xi + (1 - \tau_t)\theta$ with $\tau_t$ from cosine schedule.

Stage 2 (GPI Training) — Single Iteration:

  1. Sample a mini-batch of 128 (video clip, task description, sub-goal annotations) tuples from the robot planning dataset.
  2. VLM forward pass (frozen or LoRA): produce sub-goal embeddings $e_k$ and attention maps $\mathbf{w}_k$ for each annotated sub-goal.
  3. Target encoder forward pass (frozen): encode the video frames corresponding to each sub-goal's achieved state, producing latent targets $s^*_k$.
  4. GPI forward pass: for each $(e_k, \mathbf{w}_k)$, project through language projector, apply spatial modulation, refine with transformer module. Output: latent goal $g_k$.
  5. Compute contrastive loss $\mathcal{L}_{\text{GPI}}$ (InfoNCE) over the batch of $(g_k, \bar{s}^*_k)$ pairs.
  6. Optional: compute prediction loss if world model is jointly fine-tuned (with reduced learning rate).
  7. Backward pass through GPI parameters (and LoRA parameters if applicable). Encoder $\theta$ and predictor $\phi$ may be frozen or updated at reduced LR.
  8. Optimizer step via AdamW.
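Step 5's contrastive objective can be written compactly. This is the standard InfoNCE formulation; the temperature of 0.07 and the use of pooled, L2-normalized vectors are our assumptions, not values stated in the text:

```python
# Minimal InfoNCE over (latent goal, latent target) pairs from step 5.
import torch
import torch.nn.functional as F

def info_nce(g, s_star, temperature=0.07):
    """Each latent goal g_k should match its own pooled target s*_k
    against all other targets in the batch (positives on the diagonal)."""
    g = F.normalize(g, dim=-1)
    s = F.normalize(s_star, dim=-1)
    logits = g @ s.t() / temperature          # (B, B) cosine similarities
    labels = torch.arange(g.size(0))          # k-th goal matches k-th target
    return F.cross_entropy(logits, labels)

g = torch.randn(128, 1280)         # projected sub-goals g_k from the GPI
s_star = torch.randn(128, 1280)    # pooled target-encoder states s*_k
loss = info_nce(g, s_star)
```

Perfectly aligned pairs drive the loss toward zero, while mismatched pairs are pushed apart across the batch.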
[Figure 3 schematic: Stage 1 world-model pre-training (left) and Stage 2 GPI + VLM alignment (right); combined objective $\mathcal{L} = \mathcal{L}_{\text{pred}} + 0.5 \cdot \mathcal{L}_{\text{SIGReg}} + 1.0 \cdot \mathcal{L}_{\text{GPI}}$.]
Figure 3: Gradient flow during ThinkJEPA training. Stage 1 (left): gradients from prediction and SIGReg losses update the online encoder and predictor; the target encoder is updated via EMA only. Stage 2 (right): gradients from the InfoNCE contrastive loss update the GPI and VLM LoRA parameters; the vision encoder is frozen. Dashed green arrows indicate gradient paths; solid green borders indicate trainable modules.

8. Inference

At inference time, ThinkJEPA operates as a hierarchical planning system. The VLM performs high-level reasoning once (or upon re-planning triggers), while the world model performs real-time forward rollouts for low-level action selection via model-predictive control (MPC).

Inference pipeline:

  1. Task specification: The user provides a natural language task description $\ell$ (e.g., "sort the colored blocks into matching bins").
  2. Visual observation: The current camera frame $x_0$ is captured.
  3. VLM reasoning: $R_\psi(x_0, \ell)$ produces $K$ sub-goals with attention maps. This takes ~200ms on an A100 for the 7B VLM.
  4. Goal projection: Each sub-goal is projected to latent space via the GPI (~5ms per sub-goal).
  5. MPC loop: For each sub-goal, CEM-based planning runs $I = 5$ iterations with population $P = 512$ and elite fraction $\rho = 0.1$. Each iteration rolls out $P$ candidate action sequences for $H = 10$ world-model steps ($P \times H = 5{,}120$ rollout steps), so one CEM optimization performs $I \times P \times H = 25{,}600$ predictor forward passes. Batched on GPU, this takes ~50ms per planning step.
  6. Action execution: The first action from the optimized sequence is executed; a new observation is obtained; the MPC loop repeats.
  7. Sub-goal completion: When cosine similarity between the current state representation and the latent goal exceeds $\delta_{\text{goal}} = 0.85$, the system advances to the next sub-goal.
  8. Re-planning: If a sub-goal is not reached within $T_{\max} = 100$ steps, the VLM is re-queried with the current observation and execution history to produce revised sub-goals.
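The CEM loop of step 5 can be sketched as follows. The linear `rollout` dynamics are a hypothetical stand-in for the world-model predictor $g_\phi$, but the sample → score → elite-refit structure uses the stated $I = 5$, $P = 512$, $\rho = 0.1$, $H = 10$ and the cosine-similarity goal score:

```python
# Sketch of CEM-based MPC toward a latent goal. The toy dynamics matrix
# W_dyn is an assumption standing in for the learned predictor g_phi.
import torch
import torch.nn.functional as F

D_STATE, D_ACT = 64, 7
P, H, I, RHO = 512, 10, 5, 0.1          # population, horizon, iters, elite frac

W_dyn = torch.randn(D_STATE + D_ACT, D_STATE) * 0.05   # toy dynamics

def rollout(s0, actions):
    """Roll the toy dynamics forward H steps for every candidate sequence."""
    s = s0.expand(actions.size(0), -1)
    for h in range(actions.size(1)):
        s = torch.tanh(torch.cat([s, actions[:, h]], dim=-1) @ W_dyn)
    return s                             # final latent state per candidate

def cem_plan(s0, goal):
    mu = torch.zeros(H, D_ACT)           # action-sequence distribution params
    sigma = torch.ones(H, D_ACT)
    n_elite = max(1, int(RHO * P))
    for _ in range(I):
        actions = mu + sigma * torch.randn(P, H, D_ACT)   # sample P sequences
        s_final = rollout(s0, actions)
        score = F.cosine_similarity(s_final, goal.expand_as(s_final), dim=-1)
        elite = actions[score.topk(n_elite).indices]      # keep top 10%
        mu, sigma = elite.mean(dim=0), elite.std(dim=0) + 1e-4  # refit
    return mu[0]                         # execute only the first action (step 6)

s0 = torch.randn(1, D_STATE)             # current encoded state
goal = torch.randn(1, D_STATE)           # latent goal g_k from the GPI
a_star = cem_plan(s0, goal)
```

In the full system the scoring term would also include the reasoning-guided attention bias, and the loop re-plans from each new observation as described in steps 6–8.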

Downstream protocols:

  • Planning (primary use): The full ThinkJEPA pipeline as described above. The encoder is used for state representation, the predictor for forward rollouts, and the GPI + VLM for goal specification.
  • Linear probing: The frozen encoder $f_\theta$ produces representations; a linear classifier is trained on top for video classification benchmarks.
  • Fine-tuning: The encoder is fine-tuned end-to-end on downstream tasks with a task-specific head.
  • Representation extraction: For tasks that require only state features (e.g., reward learning), the encoder outputs are used directly without the predictor or VLM.
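The linear-probing protocol amounts to training only a linear head on frozen encoder features. A minimal sketch, where the toy encoder and input dimension are stand-ins (SSv2 does have 174 classes):

```python
# Linear probing: freeze the encoder, train only a linear classifier.
import torch
import torch.nn as nn
import torch.nn.functional as F

encoder = nn.Linear(128, 1280)          # toy stand-in for the frozen f_theta
for p in encoder.parameters():
    p.requires_grad_(False)             # frozen: no encoder gradients

head = nn.Linear(1280, 174)             # SSv2 has 174 action classes
opt = torch.optim.AdamW(head.parameters(), lr=1e-3)

x = torch.randn(32, 128)                # toy batch of inputs
y = torch.randint(0, 174, (32,))        # toy labels
with torch.no_grad():
    feats = encoder(x)                  # features computed without grad
logits = head(feats)
loss = F.cross_entropy(logits, y)

opt.zero_grad()
loss.backward()                          # gradients flow only into the head
opt.step()
```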
[Figure 4 schematic: task $\ell$ + observation $x_t$ → VLM reasoning (K sub-goals + attention maps) → GPI goal projection → CEM-based MPC with world-model rollouts scored by $\text{sim}(\hat{s}_H, g_k)$ plus reasoning-guided attention bias → execute first action, observe, repeat; re-query the VLM if stuck beyond $T_{\max}$.]
Figure 4: ThinkJEPA inference pipeline. A task description and observation are processed by the VLM (once) to produce sub-goals, which are projected to latent space via the GPI. The encoder maps the current observation to a state representation. CEM-based MPC uses the world model predictor to evaluate candidate action sequences against the latent goal, with reasoning-guided attention bias. The best action is executed, and the loop repeats. If progress stalls, the VLM is re-queried for plan revision.

9. Results & Benchmarks

Long-Horizon Planning: CALVIN ABC→D

CALVIN is a benchmark for long-horizon language-conditioned robot manipulation. The ABC→D split trains on environments A, B, C and evaluates on the unseen environment D with chains of 1–5 sequential tasks. Success is measured as the average number of tasks completed in sequence.

| Method | 1 Task | 2 Tasks | 3 Tasks | 4 Tasks | 5 Tasks | Avg. Len. |
|---|---|---|---|---|---|---|
| HULC (2022) | 82.7% | 64.1% | 47.5% | 33.2% | 21.8% | 2.49 |
| RT-2 (2023) | 85.3% | 67.8% | 51.2% | 37.6% | 26.1% | 2.68 |
| SuSIE (2024) | 87.0% | 69.3% | 53.1% | 38.9% | 27.5% | 2.76 |
| V-JEPA-2 + CEM | 82.1% | 63.5% | 46.8% | 32.4% | 20.9% | 2.46 |
| V-JEPA-2 + oracle goals | 89.2% | 72.6% | 57.3% | 43.1% | 31.0% | 2.93 |
| ThinkJEPA | 91.4% | 78.5% | 63.7% | 49.2% | 36.8% | 3.20 |

ThinkJEPA achieves 73.2% average task success across all chain lengths (weighted), a +14.8 point improvement over V-JEPA-2 with CEM planning (58.4%). Notably, ThinkJEPA even outperforms V-JEPA-2 with oracle (ground-truth) goal states by +0.27 average chain length, suggesting that the VLM's sub-goal decomposition provides better intermediate waypoints than human-specified goal images.

Language-Table

| Method | Push Block | Push to Loc. | Separate | Avg. |
|---|---|---|---|---|
| BC-Z (2022) | 74.2% | 61.5% | 43.8% | 59.8% |
| RT-1 (2023) | 81.6% | 70.3% | 52.1% | 68.0% |
| V-JEPA-2 + CEM | 78.9% | 65.7% | 47.3% | 64.0% |
| ThinkJEPA | 88.3% | 79.1% | 61.5% | 76.3% |

ThinkBench (New Benchmark)

The paper introduces ThinkBench, a suite of 200 evaluation episodes across 40 task templates requiring 3–10 step plans in a tabletop manipulation environment. Tasks are designed to require both physical reasoning (e.g., understanding stacking stability) and strategic planning (e.g., clearing obstacles before reaching a target).

| Method | 3-step | 5-step | 7-step | 10-step | Avg. |
|---|---|---|---|---|---|
| VLM-only (SayCan-style) | 72.0% | 45.3% | 22.1% | 8.5% | 37.0% |
| World model only (CEM) | 65.8% | 38.2% | 15.7% | 5.1% | 31.2% |
| ThinkJEPA | 85.2% | 68.7% | 48.3% | 29.6% | 57.9% |

The performance gap widens dramatically with task horizon: at 10 steps, ThinkJEPA achieves 29.6% vs. 8.5% for VLM-only and 5.1% for world-model-only, confirming the synergy between physical prediction and strategic reasoning.

Ablation Studies

| Ablation | CALVIN Avg. Len. | Δ |
|---|---|---|
| Full ThinkJEPA | 3.20 | — |
| w/o VLM reasoning (flat CEM) | 2.46 | −0.74 |
| w/o GPI (random goal projection) | 2.58 | −0.62 |
| w/o reasoning-guided attention | 3.01 | −0.19 |
| w/o SIGReg | 2.89 | −0.31 |
| w/o GPI refinement transformer | 2.95 | −0.25 |
| w/ ViT-L/16 encoder (instead of ViT-H) | 2.98 | −0.22 |
| w/ concatenation action conditioning | 3.05 | −0.15 |
| w/ MSE loss (instead of Smooth-L1) | 3.09 | −0.11 |
| w/ 70B VLM | 3.26 | +0.06 |
| w/o LoRA fine-tuning of VLM | 3.04 | −0.16 |

Key takeaways from ablations: (1) the VLM reasoning component is the single most impactful module (−0.74 without it); (2) the GPI is nearly as important (−0.62), confirming that the language-to-latent bridge is non-trivial; (3) SIGReg regularization matters substantially (−0.31), validating the LeJEPA lineage; (4) scaling the VLM from 7B to 70B provides diminishing returns (+0.06) at 8× inference cost.

Representation Quality (Linear Probing)

Although planning is the primary evaluation, the paper also reports linear probing on video classification to verify encoder representation quality:

| Method | SSv2 Top-1 | Kinetics-400 Top-1 |
|---|---|---|
| V-JEPA (ViT-H) | 72.2% | 78.8% |
| V-JEPA-2 (ViT-H) | 75.1% | 81.3% |
| ThinkJEPA encoder (ViT-H) | 74.8% | 81.0% |

ThinkJEPA's encoder achieves representation quality comparable to V-JEPA-2 on standard classification benchmarks, confirming that the world model pre-training stage produces equivalently strong representations. The slight gap (−0.3 on SSv2) is within noise and may reflect the action-conditioning in the predictor slightly biasing the encoder toward robot-relevant features.

10. Connection to the JEPA Family

ThinkJEPA sits at the frontier of the JEPA family tree, synthesizing advances from multiple predecessors into a qualitatively new capability:

  • From JEPA (LeCun, 2022): The foundational principle of prediction in latent space rather than pixel space. ThinkJEPA's world model is a direct instantiation of JEPA's vision for a learned world model that enables planning.
  • From I-JEPA (Assran et al., 2023): The asymmetric masking paradigm—context encoder sees only unmasked tokens, target encoder sees all tokens, predictor bridges the gap. ThinkJEPA inherits the multi-block masking strategy and narrow predictor design.
  • From V-JEPA / V-JEPA-2 (Bardes et al., 2024–2025): Extension to video with spatiotemporal masking and dense frame-level prediction. ThinkJEPA's world model backbone is directly derived from V-JEPA-2, including the architecture (ViT-H/16), the spatiotemporal block masking, and the action-conditioned predictor. The key difference is that ThinkJEPA adds the VLM reasoning layer and GPI on top of V-JEPA-2's physical prediction.
  • From LeJEPA (2025): The SIGReg spectral regularization that prevents representational collapse by maximizing the entropy of the singular value distribution of encoder outputs. ThinkJEPA adopts SIGReg directly, and ablations confirm its importance (−0.31 average chain length without it).
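The collapse-prevention intuition behind SIGReg can be checked directly: a collapsed (low-rank) embedding matrix concentrates variance in a few directions, giving a peaked singular-value spectrum and low spectral entropy, which the regularizer penalizes. A toy illustration (not LeJEPA's exact estimator):

```python
# Spectral entropy distinguishes diverse embeddings from collapsed ones.
import torch

def spectral_entropy(z):
    """Entropy of the normalized singular-value distribution of centered z."""
    z = z - z.mean(dim=0, keepdim=True)
    s = torch.linalg.svdvals(z)
    p = s / s.sum()
    return -(p * torch.log(p + 1e-8)).sum()

diverse = torch.randn(1000, 64)                        # full-rank embeddings
collapsed = torch.randn(1000, 1) @ torch.randn(1, 64)  # rank-1 collapse

h_div = spectral_entropy(diverse)     # near log(64): spectrum is spread out
h_col = spectral_entropy(collapsed)   # near 0: one direction dominates
```

Maximizing this entropy (equivalently, minimizing its negation, as in the Stage 1 objective) pushes the encoder toward the diverse, high-entropy regime.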
ThinkJEPA's key novelty is the formal integration of deliberative language reasoning into the JEPA framework. All prior JEPA variants operate at the perceptual level: they learn to represent and predict sensory data. ThinkJEPA is the first to augment JEPA with cognitive capabilities—goal decomposition, strategic planning, and adaptive re-planning—by coupling the latent world model with a VLM reasoning module through the Goal Projection Interface. This represents a shift from JEPA as a representation learning method to JEPA as a component in an agentic system.

Influence and significance: ThinkJEPA validates a prediction made in LeCun's original JEPA position paper: that a hierarchy of latent world models, combined with goal-directed planning, could serve as a foundation for autonomous intelligent agents. By demonstrating that a JEPA world model provides the "physical intuition" layer while a language model provides the "strategic reasoning" layer, ThinkJEPA establishes a concrete architecture for LeCun's proposed cognitive architecture. The success of the GPI as a cross-modal bridge also suggests a general pattern for integrating JEPA world models with other modality-specific reasoning systems (e.g., audio reasoning, tactile reasoning).

Looking forward, ThinkJEPA's architecture raises several open questions for the JEPA family: Can the VLM be replaced by a smaller, specialized planning module that is trained end-to-end? Can the GPI be made bidirectional, allowing the world model's predictions to inform VLM reasoning (not just the reverse)? And can the hierarchical planning framework be extended to even longer horizons (100+ steps) without re-planning degradation?

11. Summary

ThinkJEPA bridges the gap between JEPA-based latent world models and large Vision-Language Models by introducing a dual-system architecture for long-horizon planning. The JEPA world model (derived from V-JEPA-2 with LeJEPA's SIGReg regularization) provides dense, continuous physical prediction in latent space—fast, intuitive "System 1" reasoning about how the world evolves under actions. The VLM reasoning module provides deliberate, language-mediated "System 2" reasoning—decomposing complex tasks into ordered sub-goals. The Goal Projection Interface (GPI) translates between these two systems, mapping language-specified sub-goals into latent-space targets that the world model can plan toward.

Main contribution: ThinkJEPA demonstrates that combining physical prediction (JEPA world model) with strategic reasoning (VLM) yields dramatically better long-horizon planning than either component alone. On CALVIN ABC→D, ThinkJEPA achieves an average chain length of 3.20 vs. 2.46 for the world model alone and 2.49 for the best prior method. The gap widens with horizon length: on 10-step ThinkBench tasks, ThinkJEPA achieves 29.6% success vs. <9% for either component in isolation. This establishes JEPA not merely as a representation learning framework, but as the foundation for an agentic planning system—the first concrete realization of the cognitive architecture envisioned in LeCun's original JEPA position paper.

12. References

  1. Zhu, Y., et al. (2026). "ThinkJEPA: Empowering Latent World Models with Large Vision-Language Reasoning Model." arXiv preprint arXiv:2603.22281.
  2. LeCun, Y. (2022). "A Path Towards Autonomous Machine Intelligence." OpenReview preprint.
  3. Assran, M., et al. (2023). "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture." CVPR 2023.
  4. Bardes, A., et al. (2024). "V-JEPA: Latent Video Prediction for Visual Representation Learning." arXiv preprint arXiv:2404.16930.
  5. Bardes, A., et al. (2025). "V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning." Meta AI Technical Report.
  6. Garrido, Q., et al. (2025). "LeJEPA: Legendre Regularized Joint Embedding Predictive Architectures." arXiv preprint.
  7. Mello, R., et al. (2023). "CALVIN: A Benchmark for Language-Conditioned Policy Learning for Long-Horizon Robot Manipulation Tasks." IEEE RA-L.
  8. Lynch, C., et al. (2022). "Interactive Language: Talking to Robots in Real Time." IEEE RA-L.
  9. Ahn, M., et al. (2022). "Do As I Can, Not As I Say: Grounding Language in Robotic Affordances (SayCan)." arXiv preprint arXiv:2204.01691.
  10. Brohan, A., et al. (2023). "RT-2: Vision-Language-Action Models Transfer Web Knowledge to Robotic Control." arXiv preprint arXiv:2307.15818.
  11. Grill, J.-B., et al. (2020). "Bootstrap Your Own Latent — A New Approach to Self-Supervised Learning (BYOL)." NeurIPS 2020.
  12. Perez, E., et al. (2018). "FiLM: Visual Reasoning with a General Conditioning Layer." AAAI 2018.
  13. Rubinstein, R. Y. (1999). "The Cross-Entropy Method for Combinatorial and Continuous Optimization." Methodology and Computing in Applied Probability.
  14. Black, K., et al. (2024). "SuSIE: Subgoal Synthesis via Image Editing." arXiv preprint.
  15. Hu, E. J., et al. (2022). "LoRA: Low-Rank Adaptation of Large Language Models." ICLR 2022.