At a glance
Problem: LLMs trained purely with next-token prediction optimize surface token statistics and underexploit abstract, view-invariant semantic structure that JEPA captures in vision.
Key idea: Augment autoregressive training with a joint-embedding predictive loss over the model's own hidden states across paired views of the same content.
Modality: Text
Target / masking: Two views of the same content; one view's hidden embedding is the (stop-gradient) target predicted from the other.
Builds on: I-JEPA's latent-prediction principle, combined with standard LLM cross-entropy.
Used for: Improving LLM representation quality and downstream performance while preserving generation.

Motivation

Large language models are trained almost entirely with token-level next-token prediction, which optimizes the statistics of surface text. In vision, JEPA-style objectives showed that predicting in an abstract embedding space captures view-invariant semantic structure that pixel- or token-level objectives miss. LLM-JEPA asks whether language modeling underexploits the same kind of abstraction: two phrasings of one idea, or a problem and its solution, are semantically related in a way that token prediction does not directly reward. The aim is to add abstraction pressure to LLM training so that the model's representations become predictive of related content, not merely of the next token.

How it works

[Figure: text tokens (paired view) → context encoder f_θ and target encoder f̄_θ (EMA copy) → predictor g_φ → latent loss ‖ẑ − sg(z̄)‖², with a local loss (e.g. MLM) alongside.]
Canonical JEPA schematic for text. The input is split into a visible context and hidden targets (token-level or paired-view). The context encoder $f_\theta$ embeds what is visible; the target encoder $\bar f_\theta$ (an EMA copy, gradient stopped) embeds the targets; the predictor $g_\phi$ maps the context embedding to the target embeddings; training minimises the latent distance. A local/generative loss runs alongside latent prediction (hybrid objective). In LLM-JEPA, described below, the target embedding comes from the same LLM's hidden states under stop-gradient, and the local loss is the standard next-token cross-entropy.
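In the canonical schematic, the target encoder $\bar f_\theta$ is not trained by gradients; it tracks the context encoder through an exponential moving average of its weights. The sketch below shows that update under stated assumptions: parameters are stored as a dict of NumPy arrays, and the function name `ema_update` and the momentum `tau` are illustrative, not taken from any particular codebase. (LLM-JEPA as described in this article instead uses the model's own hidden states under stop-gradient.)

```python
import numpy as np


def ema_update(target_params, online_params, tau=0.996):
    """Return new target weights: tau * target + (1 - tau) * online.

    tau close to 1 makes the target a slow-moving copy of the online
    (context) encoder; 0.996 is an assumed illustrative value.
    No gradient ever reaches the target parameters.
    """
    return {name: tau * target_params[name] + (1.0 - tau) * online_params[name]
            for name in target_params}


# Toy example with one hypothetical weight matrix "W" per encoder.
online = {"W": np.ones((2, 2))}
target = {"W": np.zeros((2, 2))}
for _ in range(3):
    target = ema_update(target, online, tau=0.5)
print(target["W"][0, 0])  # 0.875: 0 -> 0.5 -> 0.75 -> 0.875
```

With a small `tau` the target catches up to the online weights quickly; values near 1 keep it stable, which is the usual reason for the EMA design.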

LLM-JEPA keeps standard autoregressive training and adds a latent-prediction term over the model's own hidden representations.

  • Two views of the same content are constructed — for example a problem statement and its solution, or paraphrased inputs.
  • The LLM encodes both views into hidden embeddings.
  • A predictor maps the embedding of one view to the embedding of the other, with the target view's embedding detached (stop-gradient).

This runs alongside the usual cross-entropy token-prediction objective, so the model simultaneously learns to generate tokens and to organize semantically equivalent inputs into mutually predictable embeddings. The generative loss is retained, making the embedding-prediction term complementary rather than a replacement.
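The three steps listed above can be sketched as follows. Everything here is a toy stand-in assumed for illustration: `make_pairs`, the record keys "problem"/"solution", and `ToyEncoder` (mean-pooled word vectors in place of an LLM's hidden states) are hypothetical, and the predictor is assumed to be a single linear map `W`. What the sketch shows is that one shared encoder embeds both views, and only the context embedding passes through the predictor.

```python
import numpy as np


def make_pairs(records):
    """Turn records into (context_view, target_view) string pairs.

    The keys "problem" and "solution" are hypothetical; any two views of the
    same content (e.g. a paraphrase pair) could be paired the same way.
    Records missing either view are skipped.
    """
    return [(r["problem"], r["solution"]) for r in records
            if r.get("problem") and r.get("solution")]


class ToyEncoder:
    """Stand-in for the LLM: ONE shared embedding table encodes both views."""

    def __init__(self, texts, dim=8, seed=0):
        words = sorted({w for t in texts for w in t.lower().split()})
        self.vocab = {w: i for i, w in enumerate(words)}
        self.table = np.random.default_rng(seed).normal(size=(len(words), dim))

    def __call__(self, text):
        # Mean of the word vectors plays the role of the view's hidden embedding.
        ids = [self.vocab[w] for w in text.lower().split()]
        return self.table[ids].mean(axis=0)


records = [
    {"problem": "add two and three", "solution": "two plus three is five"},
    {"problem": "double four", "solution": "four times two is eight"},
    {"problem": "no solution recorded"},  # skipped: only one view available
]
pairs = make_pairs(records)                               # step 1: paired views
encoder = ToyEncoder([text for pair in pairs for text in pair])

z_ctx = np.stack([encoder(a) for a, _ in pairs])          # step 2: embed context views
z_tgt = np.stack([encoder(b) for _, b in pairs])          # step 2: embed target views
dim = z_ctx.shape[1]
W = np.random.default_rng(1).normal(scale=0.1, size=(dim, dim))
z_hat = z_ctx @ W                                         # step 3: predict z_tgt
# z_tgt is only ever used as a fixed target (stop-gradient): in training,
# nothing would be backpropagated through it.
print(len(pairs), z_hat.shape)  # 2 (2, 8)
```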

The objective

The total loss combines the standard cross-entropy token loss with the latent prediction term:

$$\mathcal{L} = \mathcal{L}_{\text{CE}} + \lambda\,\big\lVert\, g_\phi(z_{\text{ctx}}) - \operatorname{sg}(z_{\text{tgt}})\,\big\rVert_2^2$$

where $z_{\text{ctx}}$ and $z_{\text{tgt}}$ are the LLM's hidden embeddings of the two views, $\operatorname{sg}$ is stop-gradient, and $\lambda$ weights the JEPA term. The token objective ensures generation is preserved while the embedding term pushes related views to be mutually predictable in latent space.
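To make the objective concrete, here is a minimal NumPy sketch of $\mathcal{L}$ under stated assumptions: the predictor $g_\phi$ is a linear map `W`, the JEPA term is averaged over a batch of view pairs, and `lam` stands for $\lambda$. NumPy has no autograd, so the JEPA term's gradients are written out by hand, and stop-gradient appears as the target receiving a zero gradient. In an autograd framework the same effect comes from `Tensor.detach()` in PyTorch or `jax.lax.stop_gradient` in JAX.

```python
import numpy as np


def cross_entropy(logits, targets):
    """Mean next-token cross-entropy. logits: (N, V); targets: (N,) token ids."""
    shifted = logits - logits.max(axis=1, keepdims=True)      # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()


def jepa_term(z_ctx, z_tgt, W):
    """Batch mean of ||g_phi(z_ctx) - sg(z_tgt)||_2^2, with g_phi(z) = z @ W (assumed).

    Returns the loss and its hand-derived gradients. sg(.) makes z_tgt a
    constant, so its gradient is identically zero.
    """
    batch = z_ctx.shape[0]
    diff = z_ctx @ W - z_tgt                                   # (B, D)
    loss = (diff ** 2).sum(axis=1).mean()
    grad_W = (2.0 / batch) * z_ctx.T @ diff                    # flows into the predictor
    grad_z_ctx = (2.0 / batch) * diff @ W.T                    # flows into the context branch
    grad_z_tgt = np.zeros_like(z_tgt)                          # blocked by stop-gradient
    return loss, grad_W, grad_z_ctx, grad_z_tgt


def total_loss(logits, next_tokens, z_ctx, z_tgt, W, lam=1.0):
    """L = L_CE + lam * JEPA term; lam plays the role of lambda."""
    jepa, *_ = jepa_term(z_ctx, z_tgt, W)
    return cross_entropy(logits, next_tokens) + lam * jepa


logits = np.zeros((3, 4))            # uniform predictions over a 4-token vocabulary
next_tokens = np.array([0, 2, 3])
z = np.array([[1.0, 0.0], [0.0, 1.0]])
W = np.eye(2)
print(total_loss(logits, next_tokens, z, z, W, lam=0.5))        # log 4 ≈ 1.3863
print(total_loss(logits, next_tokens, z, z + 1.0, W, lam=0.5))  # log 4 + 0.5 * 2 ≈ 2.3863
```

Because $z_{\text{tgt}}$ is treated as a constant, the JEPA term can only be reduced by moving the prediction (through `W` and, in a real model, the context embedding) toward the target; the target branch is never pulled toward the prediction.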

Key results & what's novel

LLM-JEPA carries the joint-embedding predictive principle from images and audio over to language modeling. It demonstrates that JEPA's latent-prediction philosophy can complement, rather than replace, the generative objective at the core of modern LLMs. By predicting in embedding space across paired views alongside next-token prediction, the model reportedly improves representation quality and downstream performance while retaining its generative ability. The key conceptual contribution is showing that abstraction pressure from embedding-space prediction is beneficial in a domain previously dominated by purely token-level training.

Strengths & limitations

  • + Adds JEPA-style abstraction without sacrificing generation.
  • + Reuses the model's own hidden states; modest architectural overhead.
  • + Improves representations and downstream transfer.
  • − Requires constructing paired views of content, which is task-dependent.
  • − The weight $\lambda$ and view design must be tuned to avoid harming generation.
  • − Benefits depend on the availability of meaningful semantically-equivalent pairs.

Connections & references

Builds on: I-JEPA