Motivation
Building a foundation model for brain activity dynamics (functional neuroimaging time series) runs into four difficulties at once: high dimensionality, low signal-to-noise ratio, subject heterogeneity, and the problem of encoding the brain's functional organisation in a form a model can exploit. Indexing brain regions arbitrarily throws away the connectivity structure that gives the signal meaning. Brain-JEPA (Dong et al., NeurIPS 2024) aims to learn transferable representations of brain dynamics through self-supervision, with the brain's functional layout built into the model.
How it works
Brain-JEPA applies the joint-embedding predictive recipe to brain signals.
- A context encoder embeds visible spatiotemporal patches.
- A predictor predicts the latent representations of masked regions.
- A target encoder embeds the input to supply the targets, and the predictor is trained against them with a latent prediction loss.
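The three roles listed above can be wired together in a toy forward pass. This is a minimal sketch, not the paper's architecture: each network is replaced by a single random linear map, the context is mean-pooled instead of attended over, and all names (`W_ctx`, `W_tgt`, `W_pred`, `pos`) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_patches, patch_dim, latent_dim = 8, 4, 5

# Toy stand-ins (assumptions): one linear map per network.
W_ctx = rng.normal(size=(patch_dim, latent_dim))        # context encoder
W_tgt = W_ctx.copy()                                    # target encoder, initialised as a copy
W_pred = rng.normal(size=(2 * latent_dim, latent_dim))  # predictor
pos = rng.normal(size=(n_patches, latent_dim))          # positional query per patch

x = rng.normal(size=(n_patches, patch_dim))             # flattened spatiotemporal patches
masked = np.zeros(n_patches, dtype=bool)
masked[[2, 5, 6]] = True                                # patches hidden from the context encoder

# 1. The context encoder sees only the visible patches.
z_ctx = (x[~masked] @ W_ctx).mean(axis=0)               # pooled context, shape (latent_dim,)

# 2. The predictor maps (context, position of masked patch) to a predicted latent.
n_masked = int(masked.sum())
queries = np.concatenate([np.tile(z_ctx, (n_masked, 1)), pos[masked]], axis=1)
pred = queries @ W_pred                                 # shape (n_masked, latent_dim)

# 3. The target encoder embeds the full input; only masked positions serve as targets.
target = (x @ W_tgt)[masked]                            # shape (n_masked, latent_dim)
```

The point of the sketch is the data flow: predictions and targets live in the same latent space and are compared only at masked positions, never in signal space.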
Two domain-specific innovations stand out. Gradient positioning encodes each region's place in the brain's functional organisation (its functional-gradient coordinates) as positional information, so the model respects connectivity structure rather than arbitrary indexing. A tailored spatiotemporal masking strategy hides patches along both the region axis and the time axis. Together they let the model predict the latent state of masked brain regions in a functionally informed coordinate system.
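Both ideas can be illustrated in a few lines of NumPy. The sketch below rests on several assumptions not taken from the source: the gradient coordinates are given as an array, the projection matrix is random (it would be learned in practice), the temporal half of the embedding is a standard sine/cosine encoding, and the mask is a single contiguous block of regions (ordered along the first gradient) and time patches. The function names are hypothetical and the masking rule is a simplification, not the paper's exact scheme.

```python
import numpy as np

def gradient_positional_embedding(grad_coords, n_time_patches, dim, rng):
    """Return a (n_rois, n_time_patches, dim) positional embedding.

    Half the channels come from a linear projection of each region's
    gradient coordinates, the other half from a sinusoidal time encoding.
    """
    assert dim % 4 == 0, "dim must be divisible by 4 in this sketch"
    n_rois, n_grad = grad_coords.shape
    half = dim // 2
    # Spatial half: project gradient coordinates (random stand-in for a learned map).
    w = rng.normal(scale=n_grad ** -0.5, size=(n_grad, half))
    spatial = grad_coords @ w                                   # (n_rois, half)
    # Temporal half: sine/cosine encoding of the time-patch index.
    t = np.arange(n_time_patches)[:, None]
    freqs = 1.0 / (10000 ** (np.arange(0, half, 2) / half))
    angles = t * freqs[None, :]
    temporal = np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (T, half)
    return np.concatenate([
        np.broadcast_to(spatial[:, None, :], (n_rois, n_time_patches, half)),
        np.broadcast_to(temporal[None, :, :], (n_rois, n_time_patches, half)),
    ], axis=-1)

def block_mask(grad_coords, n_time_patches, roi_frac, time_frac, rng):
    """Boolean (n_rois, n_time_patches) mask; True marks a masked patch."""
    n_rois = grad_coords.shape[0]
    order = np.argsort(grad_coords[:, 0])       # regions sorted along the first gradient
    n_r = max(1, int(round(roi_frac * n_rois)))
    n_t = max(1, int(round(time_frac * n_time_patches)))
    r0 = rng.integers(0, n_rois - n_r + 1)      # start of the region block
    t0 = rng.integers(0, n_time_patches - n_t + 1)  # start of the time block
    mask = np.zeros((n_rois, n_time_patches), dtype=bool)
    mask[order[r0:r0 + n_r], t0:t0 + n_t] = True
    return mask

rng = np.random.default_rng(0)
grad_coords = rng.normal(size=(10, 3))          # 10 regions, 3 gradient axes (made-up data)
pos = gradient_positional_embedding(grad_coords, n_time_patches=6, dim=16, rng=rng)
mask = block_mask(grad_coords, n_time_patches=6, roi_frac=0.3, time_frac=0.5, rng=rng)
```

Because the spatial half depends only on gradient coordinates, two regions that sit close together on the functional gradient receive similar positional codes regardless of their index in the atlas.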
The objective
The loss is the squared L2 distance in latent space, summed over the masked spatiotemporal brain patches:
$$\mathcal{L} = \sum_{k\in\text{mask}} \big\lVert\, g_\phi(z_{\text{ctx}}, m_k) - \operatorname{sg}[\bar f_\theta(x)_k]\,\big\rVert_2^2,$$
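with predictor $g_\phi$, context representation $z_{\text{ctx}}$ (the output of the context encoder $f_\theta$ on the visible patches), query $m_k$ identifying masked patch $k$, stop-gradient $\operatorname{sg}$, and target encoder $\bar f_\theta$. The stop-gradient means the targets are treated as constants, so the loss updates only the predictor and the context encoder. The positional encoding inside $f_\theta$ is the functional-gradient coordinate, so the prediction task is posed in a space that reflects brain connectivity rather than raw region order.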
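A minimal NumPy sketch of this loss follows, assuming the predictions and targets are already computed and stacked one row per patch. The function name `latent_prediction_loss` and the toy values are illustrative. NumPy has no automatic differentiation, so the stop-gradient is implicit here; in an autodiff framework the target would be detached explicitly.

```python
import numpy as np

def latent_prediction_loss(pred, target, mask):
    """Sum of squared L2 distances over masked patches.

    pred:   (n_patches, dim) predictor outputs
    target: (n_patches, dim) target-encoder outputs, treated as constants
    mask:   (n_patches,) boolean, True for patches in the mask set
    """
    diff = pred[mask] - target[mask]
    return float(np.sum(diff ** 2))

pred = np.array([[1.0, 0.0], [0.0, 0.0], [2.0, 2.0]])
target = np.array([[0.0, 0.0], [5.0, 5.0], [2.0, 0.0]])
mask = np.array([True, False, True])

loss = latent_prediction_loss(pred, target, mask)  # 1 + 4 = 5.0
```

The unmasked middle patch contributes nothing even though its prediction is far from its target, which mirrors the sum over $k\in\text{mask}$ in the formula.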
Key results & what's novel
The key contribution is a brain-dynamics foundation model whose functionally informed positioning and masking yield strong, transferable representations. The novelty is the pairing of two domain choices with the JEPA recipe: gradient positioning imports neuroscience structure into the model's geometry, and the spatiotemporal masking is tailored to region-and-time data rather than copied from vision. Learning from large unlabelled neuroimaging corpora, Brain-JEPA reduces the labelled-data burden that otherwise constrains downstream brain-dynamics modelling.
Strengths & limitations
- + Functional-gradient positioning encodes brain connectivity rather than arbitrary region indices.
- + Tailored spatiotemporal masking suited to neuroimaging data.
- + Transferable representations learned from unlabelled neuroimaging corpora.
- − Depends on the choice of functional atlas / gradient coordinates.
- − Learns a representation of dynamics, not an action-conditioned world model.
- − fMRI's intrinsic low SNR and subject heterogeneity still bound achievable quality.