At a glance
Problem: Single-cell RNA-seq is dominated by technical noise (dropout often >90%); reconstruction-based single-cell foundation models end up encoding that noise.
Key idea: Predict the representation of masked gene programmes from a partial cell profile against a stable teacher embedding; learn the latent cell state, not the raw counts.
Modality: Single-cell transcriptomics (scRNA-seq)
Target / masking: Mask gene modules / programmes; an EMA teacher embeds the full profile to form targets.
Builds on: I-JEPA's masked latent-prediction recipe.
Used for: Cell-type / state representation, zero-shot clustering, transfer.

Motivation

A single cell's expression vector is sparse and noisy: a gene reading zero usually means it was not captured, not that it is off. Models trained to reconstruct counts (autoencoders, many single-cell foundation models) are therefore pushed to fit dropout and batch artefacts. Cell-JEPA argues that the predictable, biologically meaningful signal lives at the level of programmes (coordinated sets of genes) and should be modelled in latent space, where unpredictable per-gene noise can simply be discarded.

How it works

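[Figure: scRNA-seq cell → gene modules, some masked; context encoder $f_\theta$ (output $z_{\text{ctx}}$); target encoder $\bar f_\theta$ (EMA copy, output $\operatorname{sg}(\bar z)$); predictor $g_\phi$; latent loss $\lVert \hat z - \operatorname{sg}(\bar z) \rVert^2$.]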
Canonical JEPA schematic applied to a single scRNA-seq cell. The input is split into a visible context and hidden targets (here, masked gene modules). The context encoder $f_\theta$ embeds what is visible; the target encoder $\bar f_\theta$ (an EMA copy of the context encoder, with gradients stopped) embeds the targets; the predictor $g_\phi$ maps the context embedding to predicted target embeddings; training minimises the latent distance between prediction and target.

An expression profile is tokenised into gene-module tokens (co-expression modules, pathways, or programmes). A subset of modules is masked.
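The tokenise-and-mask step can be sketched as below. This is a minimal illustration under stated assumptions, not Cell-JEPA's tokenizer: each module token is taken to be the mean log-normalised expression of its member genes, the module definitions are hypothetical contiguous gene blocks, and the helper names `module_tokens` and `sample_module_mask` are invented for this example.

```python
import numpy as np

def module_tokens(expr, modules):
    # One scalar per module: mean log-normalised expression of its member genes
    # (a hypothetical summary; a real tokenizer may use learned module embeddings).
    return np.array([expr[idx].mean() for idx in modules])

def sample_module_mask(n_modules, mask_ratio, rng):
    # Hide round(mask_ratio * n_modules) modules (at least one), chosen uniformly at random.
    n_masked = max(1, int(round(mask_ratio * n_modules)))
    mask = np.zeros(n_modules, dtype=bool)
    mask[rng.choice(n_modules, size=n_masked, replace=False)] = True
    return mask

rng = np.random.default_rng(0)
counts = rng.poisson(0.5, size=12).astype(float)   # toy sparse counts for 12 genes
expr = np.log1p(counts)                             # log-normalise
modules = [np.arange(0, 4), np.arange(4, 8), np.arange(8, 12)]  # hypothetical modules
tokens = module_tokens(expr, modules)
mask = sample_module_mask(len(modules), mask_ratio=0.3, rng=rng)
context_tokens = tokens[~mask]      # what the context (student) encoder sees
target_ids = np.flatnonzero(mask)   # modules whose teacher embeddings become targets
```

Averaging over a module is one simple way to make a token less sensitive to any single dropped-out gene; how modules are defined remains a free design choice.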

  • The context (student) encoder embeds the visible modules.
  • An EMA target/teacher encoder embeds the full, unmasked profile; its embeddings at the masked positions are the targets.
  • A predictor predicts the masked module representations from the visible context.
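The three roles in the list above can be wired together as in the following toy forward pass. Everything here is an assumption for illustration: the "encoder" is a tanh projection plus a mean-pool mixing step standing in for self-attention, the predictor pools the context and adds a hypothetical per-module query token $m_k$, and the width `D = 8` is arbitrary. The point is the data flow: the student sees only visible modules, the teacher sees the full profile, and targets are read off at the masked positions.

```python
import numpy as np

D = 8  # embedding width (hypothetical)

def init_encoder(n_modules, rng):
    # Toy encoder parameters: a value projection and a per-module identity embedding.
    return {"w_val": rng.normal(size=D), "E": rng.normal(size=(n_modules, D))}

def encode(params, tokens, idx):
    # Embed the modules listed in idx, then add their mean embedding so each output
    # depends on the other modules it can see (a crude stand-in for self-attention).
    h = np.tanh(np.outer(tokens[idx], params["w_val"]) + params["E"][idx])
    return h + h.mean(axis=0, keepdims=True)

def init_predictor(n_modules, rng):
    # Hypothetical predictor parameters: a context projection and one query token m_k per module.
    return {"W": rng.normal(size=(D, D)) / np.sqrt(D), "M": rng.normal(size=(n_modules, D))}

def predict(params, z_ctx, target_idx):
    # g_phi: pool the context embeddings, then combine with each masked module's query token.
    pooled = z_ctx.mean(axis=0)
    return np.tanh(pooled @ params["W"] + params["M"][target_idx])

rng = np.random.default_rng(0)
n_modules = 6
tokens = rng.normal(size=n_modules)                   # one value per module (toy)
mask = np.array([False, True, False, False, True, False])
ctx_idx, tgt_idx = np.flatnonzero(~mask), np.flatnonzero(mask)

student = init_encoder(n_modules, rng)
teacher = {k: v.copy() for k, v in student.items()}   # EMA teacher starts as a copy
predictor = init_predictor(n_modules, rng)

z_ctx = encode(student, tokens, ctx_idx)                        # student: visible modules only
z_bar = encode(teacher, tokens, np.arange(n_modules))[tgt_idx]  # teacher: full profile
z_hat = predict(predictor, z_ctx, tgt_idx)                      # predicted target embeddings
```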

Because the target is a teacher embedding rather than raw counts, the model is rewarded for predicting the state of a masked programme (whether it is active, and in which regime) rather than the exact, dropout-corrupted counts. The result is a set of dropout-robust cell representations.

The objective

Standard JEPA latent loss over masked gene modules $k$:

$$\mathcal{L} = \sum_{k\in\text{mask}} \big\lVert\, g_\phi(z_{\text{ctx}}, m_k) - \operatorname{sg}[\bar f_\theta(x)_k] \,\big\rVert_2^2,$$

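where $z_{\text{ctx}} = f_\theta(x_{\text{ctx}})$ is the context encoder's embedding of the visible modules, $m_k$ is the mask (query) token identifying module $k$, $\operatorname{sg}$ denotes stop-gradient, and the teacher $\bar f_\theta$ is an exponential moving average (EMA) of the student. No negative pairs and no augmentations are needed. This matters because valid augmentations for a cell are unknown, and naive ones can end up encoding batch effects.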
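The loss and the teacher update can be written in a few lines. This is a sketch rather than Cell-JEPA's training code: NumPy has no autograd, so the stop-gradient is only noted in a comment (in an autodiff framework one would use, e.g., `jax.lax.stop_gradient` or PyTorch's `.detach()`), and the default momentum `tau=0.996` is borrowed from common BYOL/I-JEPA practice, not taken from the Cell-JEPA paper.

```python
import numpy as np

def jepa_latent_loss(z_hat, z_bar):
    # Sum over masked modules k of ||z_hat_k - sg(z_bar_k)||_2^2.
    # z_bar must be treated as a constant (stop-gradient) so no gradient reaches the teacher;
    # NumPy computes no gradients, so that is implicit here.
    return float(np.sum((z_hat - z_bar) ** 2))

def ema_update(teacher, student, tau=0.996):
    # Parameter-wise EMA: teacher <- tau * teacher + (1 - tau) * student.
    # tau=0.996 is an assumed default, not a value reported for Cell-JEPA.
    return {k: tau * teacher[k] + (1.0 - tau) * student[k] for k in teacher}

# Usage: two masked modules with 2-D embeddings, and a one-parameter "network".
z_hat = np.array([[1.0, 0.0], [0.0, 1.0]])
z_bar = np.zeros((2, 2))
loss = jepa_latent_loss(z_hat, z_bar)   # each row contributes 1.0
new_teacher = ema_update({"w": np.zeros(2)}, {"w": np.ones(2)}, tau=0.9)
```

Only the student and predictor are updated by gradient descent on this loss; the teacher moves slowly towards the student through the EMA, which is what keeps the targets stable.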

Key results & what's novel

Cell-JEPA reports strong unsupervised structure: an AvgBIO of 0.72 for zero-shot cell-type clustering, versus 0.53 for scGPT, consistent with the claim that latent prediction resists dropout better than reconstruction. This comes with an important, candidly reported finding: the approach improves absolute-state reconstruction more than it improves perturbation effect-size estimation. The practical reading is that Cell-JEPA is an excellent state encoder, but predicting how a cell responds to an intervention is a different objective, one that needs an action-conditioned model on top.
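For readers unfamiliar with the metric: in scIB-style single-cell benchmarks, AvgBIO is usually the mean of NMI and ARI (clusters vs. cell-type labels) and the cell-type silhouette width (ASW) rescaled to $[0, 1]$. The exact variant and clustering method behind the 0.72 / 0.53 figures are not given here, so the function below (`avg_bio`, a hypothetical name) is a sketch of the common definition, not the paper's evaluation code.

```python
import numpy as np
from sklearn.metrics import (adjusted_rand_score, normalized_mutual_info_score,
                             silhouette_score)

def avg_bio(embedding, cell_types, clusters):
    # embedding : (n_cells, d) cell embeddings from a frozen encoder
    # cell_types: ground-truth labels; clusters: unsupervised assignments on the embedding
    nmi = normalized_mutual_info_score(cell_types, clusters)
    ari = adjusted_rand_score(cell_types, clusters)
    asw = (silhouette_score(embedding, cell_types) + 1.0) / 2.0  # rescale [-1, 1] -> [0, 1]
    return (nmi + ari + asw) / 3.0

# Usage: two well-separated toy "cell types"; compare perfect vs. shuffled clusterings.
rng = np.random.default_rng(0)
emb = np.vstack([rng.normal(size=(50, 2)), rng.normal(size=(50, 2)) + 10.0])
labels = np.repeat([0, 1], 50)
perfect = avg_bio(emb, labels, labels)
shuffled = avg_bio(emb, labels, rng.permutation(labels))
```

In zero-shot evaluation the encoder is frozen and the clusters come from an unsupervised method (often Leiden) run on its embeddings, so the score reflects the representation rather than any fine-tuning.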

Strengths & limitations

  • + Dropout- and noise-robust representations; no augmentations or negatives.
  • + Strong zero-shot cell-type organisation; a clean substrate for downstream models.
  • − Representation quality does not translate into accurate perturbation effect sizes; it is a state encoder, not a world model.
  • − Results were characterised in a limited (e.g. single cell-line) setting; broad generalisation is still to be established.
  • − Performance depends on how gene modules are defined.

Connections & references

Builds on: I-JEPA