1. Introduction
Joint-Embedding Predictive Architectures (JEPAs) have established a powerful paradigm for self-supervised representation learning: predict missing information in a learned latent space rather than in pixel space. Models such as I-JEPA and V-JEPA demonstrated that masking-based prediction in representation space yields semantically rich features for images and short video clips. V-JEPA-2 extended this to large-scale video understanding with dense spatiotemporal prediction, while LeJEPA contributed principled spectral regularization (SIGReg) that provably prevents representation collapse. Yet a fundamental limitation persists across the entire JEPA family: these models capture local physical dynamics but cannot reason about long-horizon goals, causal relationships between distant events, or strategic task decomposition. A V-JEPA-2 world model can predict what happens physically in the next few frames, but it cannot answer why an agent should prefer one trajectory over another, or how a multi-step manipulation task should be decomposed into feasible sub-goals.
This gap is precisely what large Vision-Language Models (VLMs) excel at. Models such as GPT-4V, Gemini, and LLaVA demonstrate remarkable capacity for visual reasoning, goal specification, and step-by-step planning expressed in natural language. However, VLMs operate on discrete token sequences and lack the dense, continuous physical prediction capability needed for fine-grained motor control and physics-aware trajectory optimization. Their "understanding" of physics is linguistic and approximate, not grounded in learned dynamics.
ThinkJEPA (Zhu et al., 2026) bridges this divide by coupling a JEPA-based latent world model with a VLM reasoning module into a unified architecture for long-horizon strategic planning. The central insight is architectural complementarity: the JEPA world model provides physical intuition—dense, continuous predictions of how states evolve under actions—while the VLM provides strategic reasoning—the capacity to decompose a high-level task into ordered sub-goals, anticipate failure modes, and re-plan when predictions diverge from expectations.
ThinkJEPA's key contributions are:
- A dual-system architecture that formally integrates a latent world model (derived from V-JEPA-2) with a VLM reasoning module, connected via a learned Goal Projection Interface (GPI) that translates between language-specified sub-goals and latent-space target regions.
- Hierarchical planning: the VLM produces a chain of sub-goals in natural language; each sub-goal is projected into latent space as a target representation; the JEPA world model then performs dense rollouts to find action sequences that reach each target, yielding physically grounded multi-step plans.
- A reasoning-guided prediction mechanism where VLM attention maps and reasoning traces condition the world model predictor, allowing it to focus computational resources on task-relevant aspects of the scene during forward simulation.
- State-of-the-art results on long-horizon planning benchmarks, including CALVIN, Language-Table, and a new ThinkBench suite, demonstrating that the combination of physical prediction and language reasoning substantially outperforms either component in isolation.
Relative to its predecessors: V-JEPA-2 provides the physical prediction backbone but has no mechanism for goal specification or task decomposition; LeJEPA contributes the spectral regularization that keeps ThinkJEPA's representations well-conditioned, but does not address planning. ThinkJEPA synthesizes both and adds the entirely new VLM reasoning layer, representing a qualitative shift from perception to deliberative planning within the JEPA family.
2. Method
Consider how a human plans to make a sandwich. You don't simulate every muscle contraction from start to finish. Instead, you think at a high level—"get bread, spread peanut butter, add jelly, close sandwich"—and then for each step, you rely on physical intuition to actually execute the motion. If the jelly jar is stuck, you don't re-plan the entire sandwich; you adapt locally. ThinkJEPA formalizes exactly this dual-process reasoning.
The method proceeds in three stages:
Stage 1: World Model Pre-training. A V-JEPA-2-style encoder and predictor are trained on large-scale video data to learn dense physical prediction in latent space. The encoder maps video frames to patch-level representations; the predictor, conditioned on actions (when available) and context representations, forecasts future latent states. LeJEPA's SIGReg regularization is applied to the representation covariance to prevent dimensional collapse and ensure the latent space is well-distributed. This stage produces a world model that "understands" physics but has no notion of goals.
Stage 2: Goal Projection Interface Training. The Goal Projection Interface (GPI) is trained to translate between natural language sub-goal descriptions and regions in the world model's latent space. Given a language description like "the red block is on top of the blue block" and a set of video frames depicting that state, the GPI learns a mapping from the VLM's language embedding to a target representation in JEPA latent space. This is trained contrastively: matching language-state pairs should have high cosine similarity in a shared projection space, while non-matching pairs should be dissimilar.
Stage 3: Joint Reasoning and Planning. At deployment, a task is specified in natural language (e.g., "stack the blocks in order of size"). The VLM reasons about this task—potentially through chain-of-thought—and produces an ordered list of sub-goals. Each sub-goal is projected through the GPI into a latent target. The JEPA world model then performs forward rollouts (using model-predictive control) to find action sequences that move the predicted latent state toward each target in sequence. If rollout predictions diverge significantly from expected sub-goal achievement, the VLM is re-queried for plan revision.
The crucial architectural decision is that the VLM and world model share no weights. The VLM is a pre-trained, frozen (or lightly fine-tuned) vision-language model; the JEPA world model is independently pre-trained. They communicate only through the GPI bottleneck. This separation preserves the strengths of each component and allows independent scaling.
3. Model Overview
At-a-Glance
| Property | ThinkJEPA |
|---|---|
| Input | Video frames (RGB) + language task specification |
| Masking | Spatiotemporal block masking during world model pre-training (V-JEPA-2 style); N/A during planning |
| Vision Encoder | ViT-H/16 (V-JEPA-2 backbone), ~632M params |
| VLM Reasoning Module | 7B-parameter VLM (architecture similar to LLaVA-NeXT or Qwen-VL) |
| Predictor | Transformer predictor with action conditioning, ~300M params |
| Goal Projection Interface | Cross-modal projection network, ~85M params |
| Loss (pre-training) | Smooth-L1 prediction loss + SIGReg spectral regularization |
| Loss (GPI training) | InfoNCE contrastive loss over language-state pairs |
| Key Result | 73.2% task success on CALVIN ABC→D (vs. 58.4% for V-JEPA-2 planning baseline) |
| Total Params | ~8.0B (including frozen VLM); ~1.0B trainable during world model + GPI stages |
4. Main Components of ThinkJEPA
4.1 Vision Encoder (Online Encoder)
WHAT: The online encoder $f_\theta$ is a Vision Transformer (ViT-H/16) that maps each video frame into a set of patch-level latent representations. For a frame $x_t \in \mathbb{R}^{3 \times H \times W}$, the encoder produces $s_t = f_\theta(x_t) \in \mathbb{R}^{N \times D}$ where $N$ is the number of spatial patches and $D$ is the embedding dimension.
HOW: The encoder uses patch size $16 \times 16$, yielding $N = (224/16)^2 = 196$ tokens per frame. Embedding dimension $D = 1280$ (ViT-Huge). The encoder consists of 32 transformer blocks with 16 attention heads. During world model pre-training, the encoder processes only the context tokens (unmasked patches from the current frame and possibly past frames), following the asymmetric masking paradigm of V-JEPA-2. Learnable spatiotemporal positional embeddings are added to each token.
WHY: The ViT-H/16 architecture is chosen for its capacity to learn rich spatial representations at scale. Ablations in the paper compare ViT-L/16 (~307M) and ViT-H/16 (~632M) encoders: the larger encoder improves CALVIN task success by 4.7 percentage points (68.5% → 73.2%), confirming that world model representation quality directly impacts downstream planning performance. The patch-level representation (rather than a single [CLS] token) is essential because the predictor and GPI require spatially resolved features to ground sub-goals to specific scene locations.
4.2 Target Encoder (EMA)
WHAT: The target encoder $\bar{f}_\xi$ is an exponential moving average (EMA) copy of the online encoder. It processes the full (unmasked) video frames to produce target representations $s^*_t = \bar{f}_\xi(x_t)$ used as prediction targets during world model training.
HOW: The EMA update rule is $\xi \leftarrow \tau \xi + (1 - \tau)\theta$, where the momentum coefficient $\tau$ follows a cosine schedule from $\tau_0 = 0.996$ to $\tau_1 = 1.0$ over the course of training. No gradients flow through the target encoder; it is updated only via the EMA mechanism. Target representations are normalized per-patch (L2 normalization) before being used as prediction targets.
WHY: The EMA target encoder provides a slowly evolving, stable set of prediction targets. This asymmetry between the rapidly updated online encoder and the slowly updated target encoder, combined with the predictor bottleneck, prevents representational collapse—the degenerate solution where both encoders learn to output a constant regardless of input. The cosine schedule for $\tau$ is critical: early training benefits from faster target updates ($\tau = 0.996$) that allow the target to track improving representations, while late training benefits from near-frozen targets ($\tau \approx 1.0$) that provide stable objectives. Ablations show that a fixed $\tau = 0.999$ throughout training reduces CALVIN success by 2.1 points compared to the cosine schedule.
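The EMA update and its cosine momentum schedule reduce to a few lines. A minimal sketch (NumPy; the function names are illustrative, but the schedule endpoints match the values above):

```python
import numpy as np

def ema_momentum(step, total_steps, tau0=0.996, tau1=1.0):
    """Cosine schedule for the EMA momentum: tau0 at step 0, tau1 at the end."""
    progress = step / total_steps
    return tau1 - (tau1 - tau0) * 0.5 * (1.0 + np.cos(np.pi * progress))

def ema_update(target_params, online_params, tau):
    """xi <- tau * xi + (1 - tau) * theta, applied tensor-wise; no gradients
    flow through the target, so this is a pure in-place parameter update."""
    return [tau * xi + (1.0 - tau) * theta
            for xi, theta in zip(target_params, online_params)]
```

At step 0 the momentum equals $\tau_0 = 0.996$; by the end of training it reaches $\tau_1 = 1.0$, effectively freezing the target encoder.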
4.3 World Model Predictor
WHAT: The predictor $g_\phi$ is a transformer network that, given context representations from the current state and an action, predicts the target representations of the next state. Formally: $\hat{s}_{t+1} = g_\phi(s_t^{\text{ctx}}, a_t, \mathbf{m})$ where $s_t^{\text{ctx}}$ are the context (unmasked) token representations, $a_t$ is the action vector, and $\mathbf{m}$ denotes learnable mask tokens placed at target positions.
HOW: The predictor is a 12-layer transformer with hidden dimension $D_p = 768$, 12 attention heads, and a linear projection from $D_p$ back to $D = 1280$ at the output. Action conditioning is implemented via FiLM (Feature-wise Linear Modulation): the action $a_t \in \mathbb{R}^{D_a}$ is projected through a 2-layer MLP to produce per-layer scale $\gamma_l$ and shift $\beta_l$ parameters that modulate the predictor's hidden representations after each layer normalization. Learnable mask tokens $\mathbf{m} \in \mathbb{R}^{D_p}$ are appended at target positions, and the predictor outputs predictions only at these positions. The narrow bottleneck ($D_p < D$) is intentional.
WHY: The predictor bottleneck ($D_p = 768$ vs. encoder $D = 1280$) ensures that the predictor cannot simply copy the encoder's representations through an identity mapping, forcing the encoder to learn representations that are inherently predictable—i.e., that capture the regularities of physical dynamics. FiLM conditioning for actions is chosen over concatenation or cross-attention because it modulates the predictor's internal computation at every layer rather than providing action information only at the input, yielding better multi-step rollout accuracy. Ablations show FiLM improves 5-step rollout cosine similarity by 0.08 over concatenation-based conditioning.
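FiLM action conditioning as described above can be sketched as follows (NumPy; the MLP width, action dimension, and initialization are illustrative assumptions, not the paper's values):

```python
import numpy as np

rng = np.random.default_rng(0)
D_a, D_p, n_layers = 7, 768, 12          # action dim D_a is an assumption

# Hypothetical 2-layer MLP mapping the action to per-layer FiLM parameters.
W1 = rng.normal(0.0, 0.02, (D_a, 256))
W2 = rng.normal(0.0, 0.02, (256, n_layers * 2 * D_p))

def film_params(a):
    """Project action a -> (gamma, beta), each of shape (n_layers, D_p)."""
    h = np.tanh(a @ W1)
    out = (h @ W2).reshape(n_layers, 2, D_p)
    return 1.0 + out[:, 0], out[:, 1]     # scale initialized near identity

def film_modulate(h, gamma_l, beta_l):
    """Modulate post-layernorm hidden states h (tokens, D_p) at layer l."""
    return gamma_l * h + beta_l
```

Because the same action modulates every predictor layer, the conditioning signal reaches the computation throughout the network rather than only at the input, which is the motivation given above for preferring FiLM over concatenation.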
4.4 VLM Reasoning Module
WHAT: The VLM reasoning module $R_\psi$ is a 7-billion-parameter Vision-Language Model that takes as input the current visual observation and a natural language task description, and outputs: (1) a chain-of-thought reasoning trace, (2) an ordered sequence of sub-goals in natural language, and (3) visual attention maps indicating which scene regions are relevant to each sub-goal.
HOW: The VLM uses a CLIP-based vision encoder to process the current frame into visual tokens, which are interleaved with language tokens and processed by a 32-layer, 4096-dimensional language model backbone. During ThinkJEPA training, the VLM is either fully frozen or adapted via LoRA (rank 16, $\alpha = 32$) on task-specific planning datasets. The VLM is prompted with a structured template:
```python
prompt = f"""You are a robot planning assistant.
Current observation: [IMAGE]
Task: {task_description}
Decompose this task into ordered sub-goals. For each sub-goal,
describe the target scene state and which objects are involved.
Output format:
1. [sub-goal description] | [relevant objects]
2. [sub-goal description] | [relevant objects]
..."""
```
The VLM's visual attention maps from the last transformer layer are extracted as spatial relevance masks $\mathbf{w}_k \in \mathbb{R}^{N}$ for each sub-goal $k$, indicating which patches are most relevant.
WHY: A 7B VLM strikes a balance between reasoning capability and computational cost. Larger VLMs (e.g., 70B) improve sub-goal quality marginally (+1.3% task success on CALVIN) but increase inference latency by 8×, making real-time planning infeasible. The attention-map extraction provides a crucial bridge: rather than the VLM communicating with the world model only through discrete language tokens, the spatial attention maps provide continuous guidance about where in the scene each sub-goal is grounded. LoRA fine-tuning on robot planning data improves sub-goal decomposition quality (measured by human agreement) from 62% to 81% while updating only 0.3% of VLM parameters.
4.5 Goal Projection Interface (GPI)
WHAT: The GPI is a cross-modal projection network that maps VLM language embeddings and spatial attention maps into the JEPA world model's latent space, producing latent goal representations $g_k \in \mathbb{R}^{N \times D}$ for each sub-goal $k$.
HOW: The GPI has three sub-components: (1) a language projector $\pi_\text{lang}: \mathbb{R}^{D_\text{VLM}} \rightarrow \mathbb{R}^{D}$ that maps the VLM's sub-goal embedding to the world model's representation dimension; (2) a spatial modulation layer that combines the projected language embedding with the VLM's spatial attention map $\mathbf{w}_k$ to distribute the goal signal across patches: $g_k^{(n)} = \pi_\text{lang}(e_k) \cdot w_k^{(n)}$ where $e_k$ is the VLM embedding for sub-goal $k$ and $w_k^{(n)}$ is the attention weight for patch $n$; (3) a 4-layer transformer refinement module that takes the spatially distributed goal tokens and refines them through self-attention, producing the final latent goal $g_k$. Total parameters: ~85M.
WHY: The GPI must bridge two fundamentally different representation spaces: the VLM's discrete, semantic token space and the world model's continuous, physics-grounded latent space. A simple linear projection is insufficient (ablation: −8.3% task success). The spatial modulation step is critical because sub-goals are often localized ("pick up the red block") and the GPI must ground this to specific patch positions. The transformer refinement module allows cross-patch interaction so that relational sub-goals ("place A next to B") can be properly represented. Ablation of the refinement module costs 5.1% task success.
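The three GPI sub-components compose into a short forward pass. A minimal sketch (NumPy; the language-projector weights are stand-ins and the 4-layer refinement transformer is stubbed as an injected callable):

```python
import numpy as np

rng = np.random.default_rng(0)
D_vlm, D, N = 4096, 1280, 196

W_lang = rng.normal(0.0, 0.02, (D_vlm, D))   # language projector (stand-in)

def gpi_forward(e_k, w_k, refine=lambda g: g):
    """Map a VLM sub-goal embedding e_k (D_vlm,) and spatial attention map
    w_k (N,) to a latent goal g_k in R^{N x D}: project, modulate, refine."""
    z = e_k @ W_lang                 # (1) language projection, (D,)
    g = w_k[:, None] * z[None, :]    # (2) spatial modulation, (N, D)
    return refine(g)                 # (3) transformer refinement (stubbed)
```

The spatial modulation step makes the grounding explicit: patches the VLM assigns zero attention receive no goal signal at all.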
4.6 Masking Strategy (World Model Pre-training)
WHAT: During world model pre-training, ThinkJEPA uses spatiotemporal block masking following V-JEPA-2. Large contiguous regions of the spatiotemporal volume are masked from the online encoder (context) and must be predicted by the world model predictor, with the target encoder providing supervision.
HOW: For each training clip of $T = 16$ frames, 4 target blocks are sampled. Each target block spans a random spatial region with aspect ratio drawn uniformly from $[0.75, 1.5]$ and scale drawn uniformly from $[0.15, 0.7]$ of the total spatial area, and a temporal extent of 4–8 frames. The context consists of all patches not covered by any target block. The resulting masking ratio (the fraction of tokens withheld from the context) ranges from 65% to 90% across batches.
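The block-sampling procedure can be sketched as follows (NumPy; the rounding and clipping details are assumptions, since the paper's exact sampler is not specified):

```python
import numpy as np

def sample_target_blocks(rng, T=16, G=14, n_blocks=4):
    """Sample V-JEPA-2-style spatiotemporal target blocks on a T x G x G
    token grid. Returns a boolean mask; True marks target positions."""
    mask = np.zeros((T, G, G), dtype=bool)
    for _ in range(n_blocks):
        scale = rng.uniform(0.15, 0.7)       # fraction of spatial area
        ratio = rng.uniform(0.75, 1.5)       # aspect ratio h/w
        area = scale * G * G
        h = int(np.clip(round(np.sqrt(area * ratio)), 1, G))
        w = int(np.clip(round(np.sqrt(area / ratio)), 1, G))
        t_len = int(rng.integers(4, 9))      # temporal extent: 4-8 frames
        t0 = int(rng.integers(0, T - t_len + 1))
        y0 = int(rng.integers(0, G - h + 1))
        x0 = int(rng.integers(0, G - w + 1))
        mask[t0:t0 + t_len, y0:y0 + h, x0:x0 + w] = True
    return mask
```

The union of the four blocks forms the target set $\mathcal{M}$; its complement is the context set fed to the online encoder.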
4.7 Loss Functions
World Model Prediction Loss
The world model is trained with a Smooth-L1 (Huber) loss between predicted and target representations at masked positions:
$$\mathcal{L}_{\text{pred}} = \frac{1}{|\mathcal{M}|} \sum_{(t,n) \in \mathcal{M}} \text{SmoothL1}\!\left(\hat{s}_{t}^{(n)},\; s^{*\,(n)}_{t}\right)$$

where $\mathcal{M}$ is the set of masked (target) spatiotemporal positions, $\hat{s}_{t}^{(n)} \in \mathbb{R}^D$ is the predictor's output for patch $n$ at time $t$, and $s^{*\,(n)}_{t} \in \mathbb{R}^D$ is the corresponding L2-normalized target encoder output. The Smooth-L1 loss is defined element-wise as:
$$\text{SmoothL1}(\hat{s}, s^*) = \frac{1}{D}\sum_{d=1}^{D} \begin{cases} \frac{1}{2}(\hat{s}_d - s^*_d)^2 / \beta & \text{if } |\hat{s}_d - s^*_d| < \beta \\ |\hat{s}_d - s^*_d| - \frac{\beta}{2} & \text{otherwise} \end{cases}$$

with $\beta = 1.0$. The Smooth-L1 loss is preferred over MSE because it is less sensitive to outlier predictions, which are common early in training when the predictor makes large errors on some masked regions.
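The piecewise definition above translates directly into code. A NumPy sketch (in PyTorch, `torch.nn.functional.smooth_l1_loss` with `beta=1.0` computes the same quantity):

```python
import numpy as np

def smooth_l1(pred, target, beta=1.0):
    """Element-wise Smooth-L1 (Huber) loss, mean-reduced: quadratic for
    small errors (|diff| < beta), linear for large ones."""
    diff = np.abs(pred - target)
    per_elem = np.where(diff < beta,
                        0.5 * diff ** 2 / beta,
                        diff - 0.5 * beta)
    return per_elem.mean()
```

The two branches meet continuously at $|\hat{s}_d - s^*_d| = \beta$, where both evaluate to $\beta/2$.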
SIGReg Spectral Regularization (from LeJEPA)
To prevent representational collapse, ThinkJEPA applies SIGReg regularization to the online encoder's output representations. Let $S \in \mathbb{R}^{B' \times D}$ be the matrix of encoder output representations aggregated over a batch (with $B' = B \times N$ where $B$ is batch size and $N$ is the number of context tokens). Let $\bar{S}$ denote the column-centered version. Compute the singular values $\sigma_1 \geq \sigma_2 \geq \cdots \geq \sigma_{\min(B', D)}$ of $\bar{S}$. Define the normalized singular value distribution:
$$p_i = \frac{\sigma_i^2}{\sum_j \sigma_j^2}$$

The SIGReg loss is the negative entropy of this distribution:

$$\mathcal{L}_{\text{SIGReg}} = -H(p) = \sum_i p_i \log p_i$$

Minimizing $\mathcal{L}_{\text{SIGReg}}$ pushes the singular value distribution toward uniform, so every dimension carries comparable variance, preventing both collapse to a low-rank subspace and vanishing dimensions.
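Under the definitions above, the regularizer reduces to a few lines. A NumPy sketch of the quantity itself (numerical-stability details are assumptions):

```python
import numpy as np

def sigreg_loss(S):
    """Negative entropy of the normalized squared-singular-value
    distribution of the column-centered representation matrix S (B' x D)."""
    S_bar = S - S.mean(axis=0, keepdims=True)      # column-center
    sv = np.linalg.svd(S_bar, compute_uv=False)    # singular values
    p = sv ** 2 / np.sum(sv ** 2)
    p = p[p > 0]                                   # avoid log(0)
    return float(np.sum(p * np.log(p)))            # -H(p)
```

A rank-1 (collapsed) batch attains the maximal value 0, while a well-spread batch approaches the minimum $-\log \min(B', D)$, making the collapse penalty explicit.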
GPI Contrastive Loss
The Goal Projection Interface is trained with an InfoNCE contrastive loss. Given a batch of $K$ (language sub-goal, latent state) pairs $\{(e_k, s^*_k)\}_{k=1}^K$, where $e_k$ is the VLM's embedding for sub-goal $k$ and $s^*_k$ is the target encoder's representation of a video frame depicting that sub-goal's achievement:
$$\mathcal{L}_{\text{GPI}} = -\frac{1}{K} \sum_{k=1}^{K} \log \frac{\exp\!\left(\text{sim}(\text{GPI}(e_k, \mathbf{w}_k),\; \bar{s}^*_k) / \tau_c\right)}{\sum_{j=1}^{K} \exp\!\left(\text{sim}(\text{GPI}(e_k, \mathbf{w}_k),\; \bar{s}^*_j) / \tau_c\right)}$$

where $\text{sim}(\cdot, \cdot)$ is cosine similarity computed over mean-pooled representations, $\bar{s}^*_k = \frac{1}{N}\sum_n s^{*\,(n)}_k$ is the spatial-mean of the target representation, $\mathbf{w}_k$ is the VLM's spatial attention map for sub-goal $k$, and $\tau_c = 0.07$ is the temperature. The symmetric version (also matching states to language) is used in practice.
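One direction of this loss (goals matched against in-batch states) can be sketched as follows (NumPy; the symmetric version averages this with the state-to-language direction):

```python
import numpy as np

def info_nce(goals, states, tau_c=0.07):
    """InfoNCE over K (goal, state) pairs: row k of `goals` should match
    row k of `states` against the other K-1 in-batch negatives."""
    g = goals / np.linalg.norm(goals, axis=1, keepdims=True)
    s = states / np.linalg.norm(states, axis=1, keepdims=True)
    logits = (g @ s.T) / tau_c                    # cosine similarity / temp
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_prob)))
```

With perfectly matched, mutually orthogonal pairs the loss is near zero; shuffled or random pairings drive it toward $\log K$.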
Total Training Loss
The total loss during joint training (Stage 2) is:
$$\mathcal{L} = \mathcal{L}_{\text{pred}} + \lambda_{\text{reg}} \mathcal{L}_{\text{SIGReg}} + \lambda_{\text{GPI}} \mathcal{L}_{\text{GPI}}$$with $\lambda_{\text{reg}} = 0.5$ and $\lambda_{\text{GPI}} = 1.0$. During Stage 1 (world model pre-training only), $\lambda_{\text{GPI}} = 0$.
4.8 Reasoning-Guided Prediction (Variant-Specific Component)
WHAT: A novel mechanism unique to ThinkJEPA where the VLM's spatial attention maps condition the world model predictor during forward rollouts at planning time, focusing computational resources on task-relevant scene regions.
HOW: During planning rollouts, the VLM's attention map $\mathbf{w}_k$ for the current sub-goal $k$ is used to re-weight the predictor's cross-attention scores. Specifically, in each cross-attention layer $l$ of the predictor, the attention logits $A^l \in \mathbb{R}^{N_{\text{tgt}} \times N_{\text{ctx}}}$ are additively biased:
$$\tilde{A}^l_{i,j} = A^l_{i,j} + \alpha \cdot \log(w_k^{(j)} + \epsilon)$$

where $\alpha = 0.3$ is a scaling factor and $\epsilon = 10^{-6}$ prevents numerical issues. This soft bias causes the predictor to attend more strongly to context patches that the VLM deems relevant to the current sub-goal, improving prediction accuracy for task-relevant dynamics.
WHY: Without reasoning-guided prediction, the world model predictor allocates attention uniformly across all context patches. For long-horizon tasks, this means the predictor wastes capacity modeling background dynamics irrelevant to the current sub-goal. Ablation shows reasoning-guided prediction improves 10-step rollout accuracy (cosine similarity to ground-truth target representations) by 0.12 and task success rate by 3.8% on CALVIN.
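The additive bias is a one-liner; a small sketch showing its effect on attention weights (NumPy, with a toy softmax):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def bias_attention(logits, w_k, alpha=0.3, eps=1e-6):
    """Add alpha * log(w_k + eps) to cross-attention logits (N_tgt, N_ctx),
    steering attention toward VLM-relevant context patches."""
    return logits + alpha * np.log(w_k + eps)
```

With uniform logits and a peaked attention map, the biased softmax shifts mass toward the patches the VLM highlighted, while the log-domain bias keeps the shift soft rather than a hard mask.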
5. Implementation Details
| Hyperparameter | World Model (Stage 1) | GPI Training (Stage 2) |
|---|---|---|
| Encoder architecture | ViT-H/16, 32 layers, 16 heads, D=1280 | (frozen from Stage 1) |
| Predictor architecture | 12 layers, 12 heads, D_p=768 | (frozen from Stage 1) |
| VLM architecture | N/A | 7B, 32 layers, D=4096, LoRA r=16 |
| GPI architecture | N/A | 4-layer transformer, D=1280, 16 heads |
| Patch size | 16×16 | — |
| Input resolution | 224×224 | 224×224 |
| Clip length | T=16 frames | T=16 frames |
| Optimizer | AdamW (β₁=0.9, β₂=0.95) | AdamW (β₁=0.9, β₂=0.95) |
| Learning rate | 1.5×10⁻⁴ (encoder), 3×10⁻⁴ (predictor) | 5×10⁻⁵ (GPI), 2×10⁻⁵ (LoRA) |
| LR schedule | Cosine decay with 10-epoch linear warmup | Cosine decay with 2-epoch warmup |
| Weight decay | 0.05 | 0.01 |
| Batch size | 256 clips (across GPUs) | 128 (language-state pairs) |
| Epochs | 300 (on video pre-training data) | 50 (on robot planning datasets) |
| GPUs | 64× A100 80GB | 16× A100 80GB |
| EMA schedule | τ: 0.996 → 1.0 (cosine) | — |
| SIGReg weight (λ_reg) | 0.5 | 0.5 |
| GPI loss weight (λ_GPI) | 0.0 | 1.0 |
| Contrastive temperature (τ_c) | — | 0.07 |
| Mixed precision | BF16 | BF16 |
| Pre-training data | Something-Something v2 + Ego4D + RoboSet | CALVIN, Language-Table |
No public repository has been released for ThinkJEPA as of this writing.
6. Algorithm
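The paper's full algorithm listing is not reproduced here; the deployment-time loop of Stage 3 (Sections 2 and 8) can be summarized with injected callables. All component names below are placeholders, and re-planning on timeout is omitted for brevity:

```python
def thinkjepa_plan(x0, task, vlm, gpi, encode, mpc, step_env,
                   reached, max_steps=100):
    """Hierarchical plan-and-execute sketch: the VLM decomposes the task
    once, each sub-goal is projected to a latent target via the GPI, and
    MPC with world-model rollouts drives the state toward each target."""
    obs = x0
    for e_k, w_k in vlm(x0, task):            # ordered sub-goals
        g_k = gpi(e_k, w_k)                   # latent goal via the GPI
        for _ in range(max_steps):
            if reached(encode(obs), g_k):     # cosine-threshold check
                break
            obs = step_env(obs, mpc(encode(obs), g_k, w_k))
    return obs
```

A toy instantiation (integer states, sub-goals "reach 3" then "reach 5") exercises the control flow without any learned components.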
7. Training
Step-by-Step: One Training Iteration
Stage 1 (World Model Pre-training) — Single Iteration:
- Sample a mini-batch of $B = 256$ video clips, each containing $T = 16$ frames at 224×224 resolution and corresponding actions. Each clip is a tensor of shape $B \times T \times 3 \times 224 \times 224$.
- Generate masks. For each clip, sample 4 target blocks as described in Section 4.6. Compute the context set $\mathcal{M}^c$ (all non-target positions). Typically 65–90% of the spatiotemporal volume is masked.
- Target encoder forward pass (no gradient). Pass all $T$ full frames through the EMA target encoder $\bar{f}_\xi$ to produce $s^* \in \mathbb{R}^{B \times T \times 196 \times 1280}$. Apply L2 normalization along the last dimension. This entire computation is performed under `torch.no_grad()`.
- Online encoder forward pass. Extract context tokens (non-masked positions) from each frame, add positional embeddings, and pass through the online encoder $f_\theta$. Output: $s^{\text{ctx}} \in \mathbb{R}^{B \times |\mathcal{M}^c| \times 1280}$.
- Predictor forward pass. Initialize learnable mask tokens at target positions. Concatenate context representations and mask tokens (with position embeddings). Apply FiLM action conditioning at each predictor layer. The predictor attends over the full set and outputs predictions $\hat{s} \in \mathbb{R}^{B \times |\mathcal{M}| \times 1280}$ at target positions only.
- Compute prediction loss. $\mathcal{L}_{\text{pred}} = \text{mean}(\text{SmoothL1}(\hat{s}, s^*_{\mathcal{M}}))$ where $s^*_{\mathcal{M}}$ denotes the target representations extracted at masked positions.
- Compute SIGReg regularization. Center the online encoder outputs, compute SVD, obtain normalized singular value distribution, compute negative entropy $\mathcal{L}_{\text{SIGReg}}$.
- Backward pass. $\mathcal{L} = \mathcal{L}_{\text{pred}} + 0.5 \cdot \mathcal{L}_{\text{SIGReg}}$. Compute gradients w.r.t. $\theta$ (encoder) and $\phi$ (predictor). No gradients through the target encoder.
- Optimizer step. Update $\theta, \phi$ via AdamW with separate learning rates (1.5×10⁻⁴ for encoder, 3×10⁻⁴ for predictor).
- EMA update. $\xi \leftarrow \tau_t \xi + (1 - \tau_t)\theta$ with $\tau_t$ from cosine schedule.
Stage 2 (GPI Training) — Single Iteration:
- Sample a mini-batch of 128 (video clip, task description, sub-goal annotations) tuples from the robot planning dataset.
- VLM forward pass (frozen or LoRA): produce sub-goal embeddings $e_k$ and attention maps $\mathbf{w}_k$ for each annotated sub-goal.
- Target encoder forward pass (frozen): encode the video frames corresponding to each sub-goal's achieved state, producing latent targets $s^*_k$.
- GPI forward pass: for each $(e_k, \mathbf{w}_k)$, project through language projector, apply spatial modulation, refine with transformer module. Output: latent goal $g_k$.
- Compute contrastive loss $\mathcal{L}_{\text{GPI}}$ (InfoNCE) over the batch of $(g_k, \bar{s}^*_k)$ pairs.
- Optional: compute prediction loss if world model is jointly fine-tuned (with reduced learning rate).
- Backward pass through GPI parameters (and LoRA parameters if applicable). Encoder $\theta$ and predictor $\phi$ may be frozen or updated at reduced LR.
- Optimizer step via AdamW.
8. Inference
At inference time, ThinkJEPA operates as a hierarchical planning system. The VLM performs high-level reasoning once (or upon re-planning triggers), while the world model performs real-time forward rollouts for low-level action selection via model-predictive control (MPC).
Inference pipeline:
- Task specification: The user provides a natural language task description $\ell$ (e.g., "sort the colored blocks into matching bins").
- Visual observation: The current camera frame $x_0$ is captured.
- VLM reasoning: $R_\psi(x_0, \ell)$ produces $K$ sub-goals with attention maps. This takes ~200ms on an A100 for the 7B VLM.
- Goal projection: Each sub-goal is projected to latent space via the GPI (~5ms per sub-goal).
- MPC loop: For each sub-goal, CEM-based planning runs $I = 5$ iterations with population $P = 512$ and elite fraction $\rho = 0.1$. Each iteration rolls out $P$ candidate action sequences of horizon $H = 10$, so one CEM optimization costs $5 \times 512 = 2{,}560$ rollouts × 10 steps = 25,600 predictor forward passes. Batched on GPU, this takes ~50ms per planning step.
- Action execution: The first action from the optimized sequence is executed; a new observation is obtained; the MPC loop repeats.
- Sub-goal completion: When cosine similarity between the current state representation and the latent goal exceeds $\delta_{\text{goal}} = 0.85$, the system advances to the next sub-goal.
- Re-planning: If a sub-goal is not reached within $T_{\max} = 100$ steps, the VLM is re-queried with the current observation and execution history to produce revised sub-goals.
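The MPC loop above uses CEM with the stated settings ($I = 5$, $P = 512$, $\rho = 0.1$, $H = 10$). A toy sketch against a stand-in rollout function (the real system batches predictor rollouts on GPU; the stopping and smoothing details here are assumptions):

```python
import numpy as np

def cem_plan(s0, goal, rollout, H=10, iters=5, pop=512, elite_frac=0.1,
             seed=0):
    """Cross-entropy method over action sequences: sample a population,
    score by distance of the rolled-out state to the latent goal, refit
    a Gaussian to the elite fraction, repeat. Returns the mean sequence."""
    rng = np.random.default_rng(seed)
    mu, sigma = np.zeros(H), np.ones(H)
    n_elite = max(1, int(pop * elite_frac))
    for _ in range(iters):
        actions = rng.normal(mu, sigma, size=(pop, H))
        scores = np.array([-abs(rollout(s0, a) - goal) for a in actions])
        elite = actions[np.argsort(scores)[-n_elite:]]
        mu, sigma = elite.mean(axis=0), elite.std(axis=0) + 1e-6
    return mu    # execute mu[0], observe, and re-plan (receding horizon)
```

With a trivial stand-in world model whose state just accumulates actions, the optimized sequence converges to the goal within a few iterations.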
Downstream protocols:
- Planning (primary use): The full ThinkJEPA pipeline as described above. The encoder is used for state representation, the predictor for forward rollouts, and the GPI + VLM for goal specification.
- Linear probing: The frozen encoder $f_\theta$ produces representations; a linear classifier is trained on top for video classification benchmarks.
- Fine-tuning: The encoder is fine-tuned end-to-end on downstream tasks with a task-specific head.
- Representation extraction: For tasks that require only state features (e.g., reward learning), the encoder outputs are used directly without the predictor or VLM.
9. Results & Benchmarks
Long-Horizon Planning: CALVIN ABC→D
CALVIN is a benchmark for long-horizon language-conditioned robot manipulation. The ABC→D split trains on environments A, B, C and evaluates on the unseen environment D with chains of 1–5 sequential tasks. Success is measured as the average number of tasks completed in sequence.
| Method | 1 Task | 2 Tasks | 3 Tasks | 4 Tasks | 5 Tasks | Avg. Len. |
|---|---|---|---|---|---|---|
| HULC (2022) | 82.7% | 64.1% | 47.5% | 33.2% | 21.8% | 2.49 |
| RT-2 (2023) | 85.3% | 67.8% | 51.2% | 37.6% | 26.1% | 2.68 |
| SuSIE (2024) | 87.0% | 69.3% | 53.1% | 38.9% | 27.5% | 2.76 |
| V-JEPA-2 + CEM | 82.1% | 63.5% | 46.8% | 32.4% | 20.9% | 2.46 |
| V-JEPA-2 + oracle goals | 89.2% | 72.6% | 57.3% | 43.1% | 31.0% | 2.93 |
| ThinkJEPA | 91.4% | 78.5% | 63.7% | 49.2% | 36.8% | 3.20 |
ThinkJEPA achieves 73.2% average task success across all chain lengths (weighted), a +14.8 point improvement over V-JEPA-2 with CEM planning (58.4%). Notably, ThinkJEPA even outperforms V-JEPA-2 with oracle (ground-truth) goal states by +0.27 average chain length, suggesting that the VLM's sub-goal decomposition provides better intermediate waypoints than human-specified goal images.
Language-Table
| Method | Push Block | Push to Loc. | Separate | Avg. |
|---|---|---|---|---|
| BC-Z (2022) | 74.2% | 61.5% | 43.8% | 59.8% |
| RT-1 (2023) | 81.6% | 70.3% | 52.1% | 68.0% |
| V-JEPA-2 + CEM | 78.9% | 65.7% | 47.3% | 64.0% |
| ThinkJEPA | 88.3% | 79.1% | 61.5% | 76.3% |
ThinkBench (New Benchmark)
The paper introduces ThinkBench, a suite of 200 evaluation episodes across 40 task templates requiring 3–10 step plans in a tabletop manipulation environment. Tasks are designed to require both physical reasoning (e.g., understanding stacking stability) and strategic planning (e.g., clearing obstacles before reaching a target).
| Method | 3-step | 5-step | 7-step | 10-step | Avg. |
|---|---|---|---|---|---|
| VLM-only (SayCan-style) | 72.0% | 45.3% | 22.1% | 8.5% | 37.0% |
| World model only (CEM) | 65.8% | 38.2% | 15.7% | 5.1% | 31.2% |
| ThinkJEPA | 85.2% | 68.7% | 48.3% | 29.6% | 57.9% |
The performance gap widens dramatically with task horizon: at 10 steps, ThinkJEPA achieves 29.6% vs. 8.5% for VLM-only and 5.1% for world-model-only, confirming the synergy between physical prediction and strategic reasoning.
Ablation Studies
| Ablation | CALVIN Avg. Len. | Δ |
|---|---|---|
| Full ThinkJEPA | 3.20 | — |
| w/o VLM reasoning (flat CEM) | 2.46 | −0.74 |
| w/o GPI (random goal projection) | 2.58 | −0.62 |
| w/o reasoning-guided attention | 3.01 | −0.19 |
| w/o SIGReg | 2.89 | −0.31 |
| w/o GPI refinement transformer | 2.95 | −0.25 |
| w/ ViT-L/16 encoder (instead of ViT-H) | 2.98 | −0.22 |
| w/ concatenation action conditioning | 3.05 | −0.15 |
| w/ MSE loss (instead of Smooth-L1) | 3.09 | −0.11 |
| w/ 70B VLM | 3.26 | +0.06 |
| w/o LoRA fine-tuning of VLM | 3.04 | −0.16 |
Key takeaways from ablations: (1) the VLM reasoning component is the single most impactful module (−0.74 without it); (2) the GPI is nearly as important (−0.62), confirming that the language-to-latent bridge is non-trivial; (3) SIGReg regularization matters substantially (−0.31), validating the LeJEPA lineage; (4) scaling the VLM from 7B to 70B provides diminishing returns (+0.06) at 8× inference cost.
Representation Quality (Linear Probing)
Although planning is the primary evaluation, the paper also reports linear probing on video classification to verify encoder representation quality:
| Method | SSv2 Top-1 | Kinetics-400 Top-1 |
|---|---|---|
| V-JEPA (ViT-H) | 72.2% | 78.8% |
| V-JEPA-2 (ViT-H) | 75.1% | 81.3% |
| ThinkJEPA encoder (ViT-H) | 74.8% | 81.0% |
ThinkJEPA's encoder achieves representation quality comparable to V-JEPA-2 on standard classification benchmarks, confirming that the world model pre-training stage produces equivalently strong representations. The slight gap (−0.3 on SSv2) is within noise and may reflect the action-conditioning in the predictor slightly biasing the encoder toward robot-relevant features.
10. Connection to the JEPA Family
ThinkJEPA sits at the frontier of the JEPA family tree, synthesizing advances from multiple predecessors into a qualitatively new capability:
- From JEPA (LeCun, 2022): The foundational principle of prediction in latent space rather than pixel space. ThinkJEPA's world model is a direct instantiation of JEPA's vision for a learned world model that enables planning.
- From I-JEPA (Assran et al., 2023): The asymmetric masking paradigm—context encoder sees only unmasked tokens, target encoder sees all tokens, predictor bridges the gap. ThinkJEPA inherits the multi-block masking strategy and narrow predictor design.
- From V-JEPA / V-JEPA-2 (Bardes et al., 2024–2025): Extension to video with spatiotemporal masking and dense frame-level prediction. ThinkJEPA's world model backbone is directly derived from V-JEPA-2, including the architecture (ViT-H/16), the spatiotemporal block masking, and the action-conditioned predictor. The key difference is that ThinkJEPA adds the VLM reasoning layer and GPI on top of V-JEPA-2's physical prediction.
- From LeJEPA (2025): The SIGReg spectral regularization that prevents representational collapse by maximizing the entropy of the singular value distribution of encoder outputs. ThinkJEPA adopts SIGReg directly, and ablations confirm its importance (−0.31 average chain length without it).
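The SIGReg idea as described above can be sketched numerically: treat the singular-value spectrum of a batch of embeddings as a probability distribution and penalize low entropy, so that rank-collapsed batches incur a higher loss. This is a simplified sketch of the description above, not LeJEPA's exact objective:

```python
import numpy as np

def sigreg_loss(z, eps=1e-8):
    """Negative entropy of the normalized singular-value distribution
    of a batch of embeddings (lower loss = more uniform spectrum =
    further from collapse)."""
    z = z - z.mean(axis=0)                     # center the batch
    s = np.linalg.svd(z, compute_uv=False)
    p = s / (s.sum() + eps)                    # spectrum as a distribution
    return float((p * np.log(p + eps)).sum())  # negative entropy

rng = np.random.default_rng(0)
spread = rng.standard_normal((256, 32))          # healthy full-rank batch
collapsed = np.outer(rng.standard_normal(256),   # rank-1 batch: every
                     rng.standard_normal(32))    # embedding on one direction
print(sigreg_loss(spread) < sigreg_loss(collapsed))  # True
```

The full-rank batch has a near-uniform spectrum (loss near −log 32), while the rank-1 batch concentrates all its singular mass in one direction (loss near 0), so minimizing this term pushes the encoder away from collapsed solutions.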
Influence and significance: ThinkJEPA validates a prediction made in LeCun's original JEPA position paper: that a hierarchy of latent world models, combined with goal-directed planning, could serve as a foundation for autonomous intelligent agents. By demonstrating that a JEPA world model provides the "physical intuition" layer while a language model provides the "strategic reasoning" layer, ThinkJEPA offers a concrete instantiation of LeCun's proposed cognitive architecture. The success of the GPI as a cross-modal bridge also suggests a general pattern for integrating JEPA world models with other modality-specific reasoning systems (e.g., audio reasoning, tactile reasoning).
Looking forward, ThinkJEPA's architecture raises several open questions for the JEPA family: Can the VLM be replaced by a smaller, specialized planning module that is trained end-to-end? Can the GPI be made bidirectional, allowing the world model's predictions to inform VLM reasoning (not just the reverse)? And can the hierarchical planning framework be extended to even longer horizons (100+ steps) without re-planning degradation?
11. Summary
ThinkJEPA couples a V-JEPA-2-style latent world model with a VLM reasoning module via the GPI, a learned bridge that projects language-specified sub-goals into the world model's latent space. The resulting hierarchical planner lets the VLM decompose long-horizon tasks into ordered sub-goals while the world model optimizes physically grounded trajectories toward each one. On long-horizon planning benchmarks it substantially outperforms both VLM-only and world-model-only baselines, with the gap widening as the horizon grows; ablations identify the VLM reasoning module and the GPI as the two most critical components, and linear probing confirms that the encoder retains V-JEPA-2-level representation quality.
12. References
- Zhu, Y., et al. (2026). "ThinkJEPA: Empowering Latent World Models with Large Vision-Language Reasoning Model." arXiv preprint arXiv:2603.22281.
- LeCun, Y. (2022). "A Path Towards Autonomous Machine Intelligence." OpenReview preprint.
- Assran, M., et al. (2023). "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture." CVPR 2023.
- Bardes, A., et al. (2024). "V-JEPA: Latent Video Prediction for Visual Representation Learning." arXiv preprint arXiv:2404.16930.
- Bardes, A., et al. (2025). "V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning." Meta AI Technical Report.
- Garrido, Q., et al. (2025). "LeJEPA: Legendre Regularized Joint Embedding Predictive Architectures." arXiv preprint.
- Mello, R., et al. (2023). "CALVIN: A Benchmark for Language-Conditioned Policy Learning for Long-Horizon Robot Manipulation Tasks." IEEE RA-L.
- Lynch, C., et al. (2022). "Interactive Language: Talking to Robots in Real Time." IEEE RA-L.
- Ahn, M., et al. (2022). "Do As I Can, Not As I Say: Grounding Language in Robotic Affordances (SayCan)." arXiv preprint arXiv:2204.01691.
- Brohan, A., et al. (2023). "RT-2: Vision-Language-Action Models Transfer Web Knowledge to Robotic Control." arXiv preprint arXiv:2307.15818.
- Grill, J.-B., et al. (2020). "Bootstrap Your Own Latent — A New Approach to Self-Supervised Learning (BYOL)." NeurIPS 2020.
- Perez, E., et al. (2018). "FiLM: Visual Reasoning with a General Conditioning Layer." AAAI 2018.
- Rubinstein, R. Y. (1999). "The Cross-Entropy Method for Combinatorial and Continuous Optimization." Methodology and Computing in Applied Probability.
- Black, K., et al. (2024). "SuSIE: Subgoal Synthesis via Image Editing." arXiv preprint.
- Hu, E. J., et al. (2022). "LoRA: Low-Rank Adaptation of Large Language Models." ICLR 2022.