1. Introduction
The Joint Embedding Predictive Architecture (JEPA) introduced by LeCun (2022) proposes a powerful paradigm for self-supervised representation learning: rather than reconstructing raw inputs pixel-by-pixel, a system should predict latent representations of missing or future content from observed context. This principle sidesteps the combinatorial explosion of pixel-level generation and focuses learning on abstract, semantic features. Yet the original JEPA formulation operates at a single level of abstraction—a single encoder produces representations, and a single predictor maps between them. This flat architecture struggles with a fundamental challenge that pervades intelligent behavior: the world is hierarchically structured, and reasoning about it requires representations at multiple levels of temporal and spatial granularity.
Consider the task of planning a cross-country road trip. At the highest level, you reason about cities and days. At a middle level, you plan highway segments and rest stops. At the lowest level, you execute moment-to-moment steering and braking. No single representation suffices: the steering-level model would drown in combinatorial complexity if asked to plan across days, while the city-level model cannot control the wheel. Biological cognition solves this through a hierarchy of cortical areas operating at different timescales—from millisecond motor commands in primary motor cortex to multi-second action plans in prefrontal cortex. Hierarchical JEPA (H-JEPA) is the architectural answer to this multi-scale challenge.
H-JEPA was introduced as a core component of LeCun's landmark position paper, "A Path Towards Autonomous Machine Intelligence" (LeCun, 2022), which laid out a comprehensive blueprint for building autonomous agents that can learn world models, reason, and plan. Within this blueprint, H-JEPA is not merely an incremental improvement over JEPA but a fundamental architectural principle: intelligence requires a stack of world models operating at different levels of abstraction, each performing prediction in its own latent space, with information flowing both bottom-up (from fine-grained to abstract) and top-down (from abstract goals to concrete actions).
The contributions of the H-JEPA concept can be summarized as follows:
- Multi-level latent hierarchy: A formal architecture where multiple JEPA modules are stacked, each with its own encoder, predictor, and latent space, operating at progressively coarser temporal and spatial resolutions.
- Temporal abstraction: Lower levels predict over short horizons with fine-grained detail; higher levels predict over long horizons with abstract, slowly-varying representations. This directly enables multi-timescale planning.
- Bidirectional information flow: Bottom-up pathways aggregate fine-grained observations into abstract summaries; top-down pathways provide goals, context, and constraints that shape lower-level predictions and actions.
- Connection to cognitive hierarchy: The architecture mirrors the hierarchical organization of the mammalian cortex, providing a computationally grounded model of how biological intelligence handles multi-scale temporal structure.
- Planning at multiple timescales: The hierarchy enables planning algorithms that operate simultaneously at different temporal granularities, making long-horizon planning tractable by decomposing it into a cascade of shorter-horizon sub-problems.
It is important to note that H-JEPA, as presented in LeCun (2022), is primarily a conceptual architecture—a design specification for future systems rather than a single, fully-benchmarked model. However, its principles have been partially instantiated in subsequent work (Wiggins et al., 2024; Bardes et al., 2024) and continue to guide the research agenda of the JEPA family. This article treats H-JEPA as a foundational architectural concept, drawing on both the original position paper and the available open-source implementation by Wiggins et al. to provide concrete technical details where possible.
2. Method
The road-trip analogy from the introduction supplies the intuition; this section describes the architecture formally.
The core method proceeds as follows. Given a stream of sensory input (video frames, audio, sensor readings), H-JEPA processes it through a stack of $L$ levels. At each level $\ell$:
- The encoder at level $\ell$ takes the representation from the level below (or raw input at $\ell=1$) and produces a latent representation at a coarser temporal resolution. If level $\ell - 1$ has representations at timesteps $\{t_1, t_2, \ldots, t_T\}$, level $\ell$ might produce representations at $\{t_1, t_k, t_{2k}, \ldots\}$ by pooling or striding over time, effectively subsampling the temporal axis by a factor of $k$.
- The predictor at level $\ell$ takes the current latent state $s_\ell(t)$ and an action or conditioning variable $z$ and predicts the next latent state $\hat{s}_\ell(t+\Delta_\ell)$, where $\Delta_\ell$ is the temporal stride at this level. Higher levels predict further into the future.
- Bottom-up flow: The encoder at level $\ell$ receives aggregated information from level $\ell - 1$, abstracting away fine-grained details and retaining only the information relevant to the coarser timescale.
- Top-down flow: The predictor at level $\ell$ receives conditioning signals from level $\ell + 1$, which provide abstract goals or contextual constraints. This allows higher-level intentions to shape lower-level predictions.
The training signal at each level is the standard JEPA objective: the predictor's output should match the target encoder's output for the corresponding future timestep, measured in latent space. The target encoder is maintained via exponential moving average (EMA), following the paradigm established in BYOL and carried through I-JEPA. Collapse is prevented through the same mechanisms as in standard JEPA: the combination of EMA target updates, predictor bottleneck, and asymmetric architecture ensures that the latent space retains useful information rather than collapsing to a trivial constant.
3. Model Overview
At-a-Glance
| Component | H-JEPA Specification |
|---|---|
| Input | Generic (video frames, sensor sequences, or any temporal stream); the architecture is input-agnostic by design |
| Masking | N/A at the architectural level; individual JEPA modules within the hierarchy may use masking (e.g., I-JEPA-style multi-block masking at the lowest level), but H-JEPA itself is defined by its hierarchical structure rather than a specific masking strategy |
| Encoder | Per-level encoder $f_{\theta_\ell}$ (typically Vision Transformer or temporal transformer); each level produces representations at a coarser temporal resolution than the level below |
| Predictor | Per-level predictor $g_{\phi_\ell}$ (narrow transformer or MLP); conditioned on action variables and top-down signals from level $\ell+1$; predicts future latent states at the level's temporal stride |
| Target Encoder | Per-level EMA encoder $f_{\bar{\theta}_\ell}$; updated as $\bar{\theta}_\ell \leftarrow \tau \bar{\theta}_\ell + (1-\tau) \theta_\ell$ with $\tau \in [0.996, 0.9999]$ following cosine schedule |
| Loss | Per-level latent prediction loss: $\mathcal{L}_\ell = \| g_{\phi_\ell}(\mathbf{s}_\ell^t, \mathbf{z}, \mathbf{c}_{\ell+1}) - \text{sg}[f_{\bar{\theta}_\ell}(\mathbf{x}_\ell^{t+\Delta_\ell})] \|_2^2$; total loss is a weighted sum across levels |
| Key Result | Provides a principled framework for multi-timescale world models; enables hierarchical planning where high-level abstract plans are refined into low-level action sequences |
| Params | Implementation-dependent; Wiggins et al. prototype uses ViT-based encoders (~86M params per level in ViT-Base configurations); total scales as $\sim L \times P_{\text{per-level}}$ |
Training Architecture Diagram
4. Main Components of H-JEPA
4.1. Hierarchical Encoder Stack
WHAT: H-JEPA employs a stack of $L$ encoders $\{f_{\theta_1}, f_{\theta_2}, \ldots, f_{\theta_L}\}$, one per level of the hierarchy. Each encoder transforms the representation from the level below into a higher-level, temporally coarser latent space. The bottom encoder $f_{\theta_1}$ operates on raw input (e.g., image patches tokenized via a linear projection), while each subsequent encoder $f_{\theta_\ell}$ for $\ell > 1$ operates on the temporally pooled output of level $\ell - 1$.
HOW: In the Wiggins et al. implementation, each encoder is a Vision Transformer (ViT). The key architectural parameters per level are:
- Embedding dimension $D_\ell$: Typically $D_1 = 768$ (ViT-Base) at level 1, potentially reduced at higher levels (e.g., $D_2 = 384$, $D_3 = 192$) since higher levels represent more abstract, lower-dimensional features.
- Temporal stride $k_\ell$: Each level $\ell$ pools over $k_\ell$ timesteps from level $\ell - 1$. With $k_2 = 4$ and $k_3 = 4$, level 2 operates at $1/4$ the temporal resolution of level 1, and level 3 at $1/16$.
- Patch size $P$: At level 1, standard ViT patch sizes (e.g., $16 \times 16$) are used. Higher levels do not re-patchify; instead, they receive token sequences from the temporal pooling operation.
- Number of tokens $N_\ell$: At level 1, $N_1 = (H/P) \times (W/P)$ spatial tokens per frame. At higher levels, the spatial token count may be preserved or further reduced via spatial pooling.
The temporal pooling between levels can be implemented as average pooling, learned linear projections over the time axis, or attention-based aggregation. The Wiggins et al. codebase uses a TemporalEncoder class that applies a small transformer over temporal windows to produce the pooled representation.
WHY: The multi-resolution hierarchy is the defining feature of H-JEPA. Without it, the model degenerates to a flat JEPA, which struggles with long-horizon prediction because it must maintain fine-grained temporal detail at all horizons simultaneously. Ablation evidence from related hierarchical models (e.g., hierarchical VQ-VAE, Clockwork RNN) consistently shows that temporal pooling improves long-horizon prediction accuracy by 15–30% compared to flat baselines, because higher levels can learn to represent slowly-changing scene properties (object identity, spatial layout) without being distracted by fast dynamics (texture flickering, small motions).
4.2. Target Encoder (EMA)
WHAT: Each level $\ell$ maintains a target encoder $f_{\bar{\theta}_\ell}$ that is a slow-moving copy of the online encoder $f_{\theta_\ell}$. The target encoder is not trained by gradient descent; instead, its parameters are updated via exponential moving average (EMA) of the online encoder's parameters.
HOW: The EMA update rule at each level is:
$$\bar{\theta}_\ell \leftarrow \tau_\ell \bar{\theta}_\ell + (1 - \tau_\ell) \theta_\ell$$

where $\tau_\ell \in [0.996, 0.9999]$ is the momentum coefficient. Following standard practice from BYOL and I-JEPA, $\tau$ follows a cosine schedule that ramps from a lower value (e.g., $\tau_{\text{base}} = 0.996$) to a higher value (e.g., $\tau_{\text{end}} = 0.9999$) over the course of training:
$$\tau_t = \tau_{\text{end}} - (\tau_{\text{end}} - \tau_{\text{base}}) \cdot \left(\cos\left(\frac{\pi t}{T}\right) + 1\right) / 2$$

In the H-JEPA setting, each level may have its own EMA schedule, with higher levels potentially using higher momentum (slower updates) to match their slower-varying latent spaces. The Wiggins et al. implementation uses the standard I-JEPA EMA schedule across all levels.
WHY: The EMA target serves two critical purposes. First, it provides a stable, slowly-moving prediction target that prevents the online encoder and predictor from finding trivial shortcuts (representation collapse). Second, the temporal smoothing of the target encoder implicitly regularizes the learned representations, encouraging smoothly varying features that generalize better to downstream tasks. The necessity of EMA for preventing collapse in JEPA-family models has been extensively ablated in I-JEPA (Assran et al., 2023): removing EMA (i.e., setting $\tau = 0$ so the target equals the online encoder) leads to rapid collapse to a constant representation within the first few thousand training steps.
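As a concrete illustration, here is a minimal PyTorch sketch of the per-level EMA update and the cosine momentum schedule described above. The helper names (`ema_momentum`, `update_target_encoder`) are illustrative and not taken from any released codebase.

```python
import math
import torch

def ema_momentum(step: int, total_steps: int,
                 tau_base: float = 0.996, tau_end: float = 0.9999) -> float:
    # Cosine ramp from tau_base (start of training) to tau_end (end of training),
    # matching the schedule in Section 4.2.
    return tau_end - (tau_end - tau_base) * (math.cos(math.pi * step / total_steps) + 1) / 2

@torch.no_grad()
def update_target_encoder(online: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    # theta_bar <- tau * theta_bar + (1 - tau) * theta, applied parameter-wise.
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(tau).add_(p_online, alpha=1 - tau)
```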
4.3. Hierarchical Predictor
WHAT: Each level $\ell$ has a predictor $g_{\phi_\ell}$ that takes the current latent state at level $\ell$, an optional action/conditioning variable $z$, and top-down context from level $\ell + 1$, and predicts the future latent state at level $\ell$ after a temporal stride of $\Delta_\ell$. This is the component that gives H-JEPA its predictive character and distinguishes it from purely contrastive hierarchical methods.
HOW: The predictor at level $\ell$ computes:
$$\hat{\mathbf{s}}_\ell^{t + \Delta_\ell} = g_{\phi_\ell}\!\left(\mathbf{s}_\ell^t,\ \mathbf{z},\ \mathbf{c}_{\ell+1}^t\right)$$

where:
- $\mathbf{s}_\ell^t = f_{\theta_\ell}(\mathbf{x}_\ell^t) \in \mathbb{R}^{N_\ell \times D_\ell}$ is the online encoder's representation at level $\ell$, time $t$
- $\mathbf{z}$ is an action or conditioning variable (e.g., robot joint commands, discrete action tokens)
- $\mathbf{c}_{\ell+1}^t$ is the top-down context from the predictor or encoder at level $\ell + 1$, projected to be compatible with level $\ell$'s dimension
- $\Delta_\ell$ is the prediction horizon at level $\ell$ (increases with level)
The predictor architecture is deliberately kept narrow—typically a small transformer with fewer layers and smaller hidden dimension than the encoder. In I-JEPA, the predictor uses 6 transformer blocks with dimension 384 when the encoder is ViT-Large (dim 1024). The narrowness is a crucial design choice: a predictor with too much capacity could learn to bypass the encoder and extract information directly from the input, defeating the purpose of learning useful encoder representations.
The top-down conditioning $\mathbf{c}_{\ell+1}^t$ is implemented via cross-attention or concatenation. In the cross-attention variant, the predictor's query tokens attend to key-value pairs derived from the higher level's representation, allowing the higher level to modulate the lower level's predictions. In the concatenation variant, the higher-level context is projected to match the lower level's dimension and prepended to the predictor's input sequence.
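The following is a minimal PyTorch sketch of the concatenation variant: a narrow transformer predictor whose input sequence is the level-$\ell$ context tokens with a projected action token and the projected top-down context prepended. The class name and argument layout are illustrative assumptions, not the Wiggins et al. API.

```python
import torch
import torch.nn as nn

class NarrowPredictor(nn.Module):
    """Level-l predictor: narrow transformer conditioned on an action and top-down context."""
    def __init__(self, dim: int, top_down_dim: int, action_dim: int,
                 depth: int = 4, heads: int = 6):
        super().__init__()
        self.action_proj = nn.Linear(action_dim, dim)       # embed action variable z
        self.top_down_proj = nn.Linear(top_down_dim, dim)   # project c_{l+1} to level-l dim
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                           dim_feedforward=2 * dim, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, s_l, z=None, c_top=None):
        # s_l:   (B, N, dim)          context tokens at level l
        # z:     (B, action_dim)      optional action / conditioning variable
        # c_top: (B, M, top_down_dim) optional top-down context from level l+1
        tokens = [s_l]
        if z is not None:
            tokens.insert(0, self.action_proj(z).unsqueeze(1))   # (B, 1, dim)
        if c_top is not None:
            tokens.insert(0, self.top_down_proj(c_top))          # (B, M, dim)
        out = self.blocks(torch.cat(tokens, dim=1))
        return out[:, -s_l.shape[1]:]   # predictions aligned with the level-l token positions
```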
WHY: The predictor bottleneck is essential for representation quality. If the predictor were a full-capacity autoencoder, it could learn to memorize the training data rather than learning semantically meaningful representations in the encoder. The narrow predictor forces the encoder to learn representations that are predictive of the future—i.e., representations that capture the causal structure of the environment. The top-down conditioning is what makes H-JEPA hierarchical rather than merely a stack of independent JEPA modules: it allows high-level abstract plans to influence low-level predictions, creating a coherent multi-scale world model.
4.4. Masking Strategy
WHAT: While H-JEPA as an abstract architecture does not prescribe a specific masking strategy, concrete instantiations typically employ masking at each level. The masking determines which portions of the input (at each level's resolution) are observed (context) versus hidden (target), following the general JEPA paradigm. At the lowest level, this resembles I-JEPA's multi-block masking. At higher levels, masking operates over temporally coarser units.
HOW: At level $\ell$, the masking operates over the temporal-spatial token grid at that level's resolution. For a level with $T_\ell$ temporal positions and $N_\ell$ spatial positions per timestep:
- Context mask $M_\ell^{\text{ctx}}$: Selects a subset of spatiotemporal tokens that the encoder observes.
- Target mask $M_\ell^{\text{tgt}}$: Selects tokens at future timesteps (relative to the level's temporal stride) that the predictor must predict.
At level 1 (finest), masking might follow I-JEPA: 4 target blocks of aspect ratio uniformly sampled from $[0.75, 1.5]$ covering 15–20% of spatiotemporal tokens, with a single context block covering the complement. At level 2, masking operates over temporally coarser units—entire temporal segments rather than individual frames. At level 3, masking might target entire high-level events or scene transitions.
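For concreteness, here is a small sketch of I-JEPA-style block mask sampling over a spatial token grid, as might be used at the finest level. The function and parameter names are illustrative, and the exact sampling conventions differ between implementations.

```python
import math
import random
import torch

def sample_block_mask(grid_h: int, grid_w: int,
                      scale=(0.15, 0.2), aspect=(0.75, 1.5)) -> torch.Tensor:
    """Sample one rectangular target block over a grid_h x grid_w token grid.
    Returns a boolean mask of shape (grid_h, grid_w), True = target token."""
    area = random.uniform(*scale) * grid_h * grid_w   # block area as a fraction of the grid
    ratio = random.uniform(*aspect)                   # block aspect ratio
    h = max(1, min(grid_h, int(round(math.sqrt(area / ratio)))))
    w = max(1, min(grid_w, int(round(math.sqrt(area * ratio)))))
    top = random.randint(0, grid_h - h)
    left = random.randint(0, grid_w - w)
    mask = torch.zeros(grid_h, grid_w, dtype=torch.bool)
    mask[top:top + h, left:left + w] = True
    return mask

# Four target blocks; the context is the complement of their union.
targets = [sample_block_mask(14, 14) for _ in range(4)]
context = ~torch.stack(targets).any(dim=0)
```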
WHY: Multi-level masking ensures that each level learns representations appropriate to its timescale. If all levels used the same fine-grained masking, higher levels would be forced to predict fine-grained details that are below their intended resolution, wasting capacity and potentially interfering with the formation of abstract representations. The coarser masking at higher levels forces them to learn to predict at the level of events and outcomes rather than pixels and motions.
4.5. Loss Function
WHAT: The total training loss for H-JEPA is a weighted sum of per-level latent prediction losses. Each level's loss measures the discrepancy between the predictor's output and the target encoder's representation of the actual future, all computed in latent space (never in pixel space).
HOW: The per-level loss is defined as:
$$\mathcal{L}_\ell = \frac{1}{|M_\ell^{\text{tgt}}|} \sum_{(i,t) \in M_\ell^{\text{tgt}}} \left\| \hat{\mathbf{s}}_\ell^{(i,t)} - \text{sg}\!\left[\bar{\mathbf{s}}_\ell^{(i,t)}\right] \right\|_2^2$$

where:
- $\hat{\mathbf{s}}_\ell^{(i,t)} = g_{\phi_\ell}(\mathbf{s}_\ell^{\text{ctx}}, \mathbf{z}, \mathbf{c}_{\ell+1})_{(i,t)} \in \mathbb{R}^{D_\ell}$ is the predictor's output for spatial token $i$ at temporal position $t$ at level $\ell$
- $\bar{\mathbf{s}}_\ell^{(i,t)} = f_{\bar{\theta}_\ell}(\mathbf{x}_\ell^t)_i \in \mathbb{R}^{D_\ell}$ is the target encoder's representation of the actual observation at position $(i, t)$
- $M_\ell^{\text{tgt}}$ is the set of masked (target) spatiotemporal positions at level $\ell$
- $\text{sg}[\cdot]$ denotes stop-gradient — gradients do not flow through the target encoder
- $|M_\ell^{\text{tgt}}|$ is the number of target tokens (for normalization)
The total H-JEPA loss is:
$$\mathcal{L}_{\text{H-JEPA}} = \sum_{\ell=1}^{L} \lambda_\ell \, \mathcal{L}_\ell$$

where $\lambda_\ell > 0$ are per-level loss weights. The choice of $\lambda_\ell$ controls the relative importance of each level. A common schedule is $\lambda_\ell = 1$ for all levels (uniform weighting), but alternative schemes include:
- Increasing weight: $\lambda_\ell \propto \ell$, emphasizing abstract representation quality at higher levels
- Curriculum weighting: Initially $\lambda_1 \gg \lambda_L$ (train lower levels first), gradually equalizing during training
- Uncertainty-based: $\lambda_\ell = 1 / (2\sigma_\ell^2)$ with learned per-level uncertainty $\sigma_\ell$, following multi-task learning practice (Kendall et al., 2018)
The representations are typically L2-normalized before computing the loss, following I-JEPA practice:
$$\hat{\mathbf{s}}_\ell^{(i,t)} \leftarrow \hat{\mathbf{s}}_\ell^{(i,t)} / \|\hat{\mathbf{s}}_\ell^{(i,t)}\|_2, \quad \bar{\mathbf{s}}_\ell^{(i,t)} \leftarrow \bar{\mathbf{s}}_\ell^{(i,t)} / \|\bar{\mathbf{s}}_\ell^{(i,t)}\|_2$$

WHY: The L2 loss in latent space (rather than cross-entropy or pixel reconstruction) is the defining feature of the JEPA family. It allows the model to make distributional predictions that capture the expected representation without needing to model every detail of the future observation. The stop-gradient on the target encoder, combined with EMA updates, prevents the trivial solution where both encoder and predictor collapse to a constant. The per-level weighting $\lambda_\ell$ is necessary because different levels may have very different loss magnitudes (lower levels typically have higher loss due to predicting more detailed content), and without balancing, optimization would be dominated by the noisiest level.
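A minimal sketch of the per-level loss and the weighted sum, assuming predictions and targets have already been gathered at the masked target positions; function names are illustrative.

```python
import torch
import torch.nn.functional as F

def level_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Per-level JEPA loss over target tokens.
    pred, target: (num_target_tokens, D_l); target comes from the EMA encoder."""
    pred = F.normalize(pred, dim=-1)                  # L2-normalize predictions
    target = F.normalize(target.detach(), dim=-1)     # stop-gradient on EMA targets
    return ((pred - target) ** 2).sum(dim=-1).mean()  # mean squared L2 distance per token

def hjepa_loss(per_level_preds, per_level_targets, weights):
    """Weighted sum of per-level losses: L_total = sum_l lambda_l * L_l."""
    return sum(w * level_loss(p, t)
               for w, (p, t) in zip(weights, zip(per_level_preds, per_level_targets)))
```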
4.6. Temporal Abstraction Module
WHAT: The temporal abstraction module is the inter-level connector that transforms representations from one level's temporal resolution to the next level's coarser resolution. This is the component most unique to H-JEPA, as it is absent from flat JEPA variants.
HOW: Given level $\ell$'s output $\mathbf{S}_\ell \in \mathbb{R}^{B \times T_\ell \times N_\ell \times D_\ell}$, the temporal abstraction module produces a coarser representation for level $\ell + 1$:
$$\mathbf{X}_{\ell+1} = \text{TemporalAbstract}(\mathbf{S}_\ell) \in \mathbb{R}^{B \times T_{\ell+1} \times N_\ell \times D_{\ell+1}}$$

where $T_{\ell+1} = T_\ell / k_{\ell+1}$ and $D_{\ell+1}$ may differ from $D_\ell$. Implementation variants include:
- Temporal average pooling: Average over windows of $k_{\ell+1}$ timesteps, followed by a linear projection from $D_\ell$ to $D_{\ell+1}$
- Temporal attention pooling: A small cross-attention module where $T_{\ell+1}$ learnable query tokens attend to the $T_\ell$ input tokens, producing a compressed temporal representation
- Strided convolution: A 1D convolution with stride $k_{\ell+1}$ over the temporal axis
WHY: The temporal abstraction module implements the key insight that higher levels should operate at coarser temporal granularity. Without this component, all levels would process the same temporal resolution, negating the benefit of the hierarchy. The choice of abstraction mechanism affects what information is retained versus discarded: average pooling preserves the mean signal, attention pooling can selectively retain the most informative timesteps, and strided convolution learns a general downsampling filter. Attention pooling tends to produce the best results in practice (as shown in Perceiver-style architectures) but is more computationally expensive.
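A minimal sketch of the attention-pooling variant, assuming level-$\ell$ outputs of shape $B \times T_\ell \times N_\ell \times D_\ell$ and a pooling window equal to the stride; the class name and details are illustrative rather than taken from the Wiggins et al. `TemporalEncoder`.

```python
import torch
import torch.nn as nn

class AttentionTemporalPool(nn.Module):
    """Attention-pooling temporal abstraction: one learned query attends to each
    window of `stride` timesteps from the level below (illustrative sketch)."""
    def __init__(self, in_dim: int, out_dim: int, stride: int, heads: int = 4):
        super().__init__()
        self.stride = stride
        self.query = nn.Parameter(torch.randn(1, 1, out_dim) * 0.02)  # one query per window
        self.in_proj = nn.Linear(in_dim, out_dim)
        self.attn = nn.MultiheadAttention(out_dim, heads, batch_first=True)

    def forward(self, s):                        # s: (B, T_l, N, D_l)
        B, T, N, _ = s.shape
        s = self.in_proj(s)                      # (B, T_l, N, D_{l+1})
        # Group timesteps into windows of length `stride`, then pool each window.
        s = s.view(B, T // self.stride, self.stride, N, -1)
        s = s.permute(0, 1, 3, 2, 4).reshape(-1, self.stride, s.shape[-1])  # (B*T'*N, k, D)
        q = self.query.expand(s.shape[0], -1, -1)
        pooled, _ = self.attn(q, s, s)           # (B*T'*N, 1, D_{l+1})
        return pooled.view(B, T // self.stride, N, -1)   # (B, T_{l+1}, N, D_{l+1})
```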
4.7. Top-Down Conditioning Module
WHAT: The top-down conditioning module passes information from higher levels down to lower-level predictors, enabling abstract goals and plans to influence fine-grained predictions.
HOW: Given the higher level's predictor state $\mathbf{h}_{\ell+1} \in \mathbb{R}^{B \times T_{\ell+1} \times D_{\ell+1}}$, the top-down module produces a conditioning signal:
$$\mathbf{c}_{\ell+1 \to \ell} = \text{TopDown}(\mathbf{h}_{\ell+1}) \in \mathbb{R}^{B \times T_\ell \times D_\ell}$$

This involves (1) temporal upsampling from $T_{\ell+1}$ to $T_\ell$ (e.g., repeat-interleave or learned upsampling), and (2) dimension projection from $D_{\ell+1}$ to $D_\ell$. The conditioning signal $\mathbf{c}_{\ell+1 \to \ell}$ is then injected into the lower-level predictor via cross-attention, additive conditioning, or FiLM-style modulation (Perez et al., 2018).
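A minimal sketch of the upsample-and-project pipeline with additive injection (cross-attention and FiLM are drop-in alternatives); names and shapes are illustrative assumptions.

```python
import torch
import torch.nn as nn

class TopDownConditioning(nn.Module):
    """Upsample level-(l+1) state to level-l temporal resolution and project to D_l
    (illustrative sketch)."""
    def __init__(self, dim_above: int, dim_below: int, stride: int):
        super().__init__()
        self.stride = stride                  # temporal factor between the two levels
        self.proj = nn.Linear(dim_above, dim_below)

    def forward(self, h_above: torch.Tensor) -> torch.Tensor:
        # h_above: (B, T_{l+1}, D_{l+1})
        c = self.proj(h_above)                               # (B, T_{l+1}, D_l)
        c = torch.repeat_interleave(c, self.stride, dim=1)   # (B, T_l, D_l), nearest-neighbour upsampling
        return c

# Additive injection into the level-l predictor input (one of the variants above):
# s_l_conditioned = s_l + top_down(h_above).unsqueeze(2)   # broadcast over spatial tokens
```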
WHY: Without top-down flow, the hierarchy reduces to a stack of independent JEPA modules that share no information between levels. Top-down conditioning is what makes the hierarchy a coherent world model rather than a collection of independent predictors. In cognitive science terms, this corresponds to the "predictive processing" framework where higher cortical areas generate predictions that constrain the interpretation of lower-level sensory input. In robotics, this corresponds to task decomposition: a high-level planner sets subgoals that guide low-level controllers.
5. Implementation Details
The following table synthesizes implementation details from LeCun (2022) and the open-source reference implementation by Wiggins et al. (available at github.com/jonwiggins/H-JEPA). Note that many parameters are configuration-dependent and the values below represent typical settings rather than a single canonical configuration.
| Hyperparameter | Level 1 | Level 2 | Level 3 |
|---|---|---|---|
| Encoder architecture | ViT-Base/16 | ViT-Small | ViT-Tiny / MLP |
| Encoder layers | 12 | 6 | 4 |
| Attention heads | 12 | 6 | 4 |
| Embedding dim $D_\ell$ | 768 | 384 | 192 |
| Patch size | 16×16 | N/A (receives tokens) | N/A (receives tokens) |
| Temporal stride $k_\ell$ | 1 | 4 | 4 |
| Predictor layers | 6 | 4 | 2 |
| Predictor dim | 384 | 192 | 96 |
| Optimizer | AdamW ($\beta_1=0.9$, $\beta_2=0.999$, weight decay $=0.05$) | | |
| Base learning rate | $1.5 \times 10^{-4}$ (scales linearly with effective batch size / 256) | | |
| LR schedule | Cosine decay with 40-epoch linear warmup | | |
| Batch size | 2048 (across GPUs) | | |
| Training epochs | 300–600 (depending on dataset scale) | | |
| GPUs | 8–64 × A100 (80GB) with DDP | | |
| EMA schedule | Cosine: $\tau_{\text{base}}=0.996 \to \tau_{\text{end}}=0.9999$ | | |
| Loss weights $\lambda_\ell$ | 1.0 | 1.0 | 1.0 |
In the Wiggins et al. codebase, the primary classes are:
- `HJEPA` — top-level model coordinating the multi-level hierarchy
- `JEPALevel` — a single level containing an encoder, predictor, and target encoder
- `TemporalEncoder` — temporal abstraction between levels
- `Predictor` — the narrow transformer predictor with optional top-down conditioning
```python
# Simplified structure from the Wiggins et al. H-JEPA implementation
import torch
import torch.nn as nn
import torch.nn.functional as F


class HJEPA(nn.Module):
    def __init__(self, num_levels=3, encoder_dims=(768, 384, 192),
                 temporal_strides=(1, 4, 4), predictor_dims=(384, 192, 96)):
        super().__init__()
        # One JEPA module (encoder + predictor + EMA target encoder) per level.
        self.levels = nn.ModuleList([
            JEPALevel(
                encoder_dim=encoder_dims[l],
                predictor_dim=predictor_dims[l],
                temporal_stride=temporal_strides[l],
            )
            for l in range(num_levels)
        ])
        # Temporal abstraction modules connecting level l to level l+1.
        self.temporal_pools = nn.ModuleList([
            TemporalEncoder(in_dim=encoder_dims[l], out_dim=encoder_dims[l + 1],
                            stride=temporal_strides[l + 1])
            for l in range(num_levels - 1)
        ])

    def forward(self, x, actions=None):
        losses = []
        representations = []
        h = x  # B×T×C×H×W for video input
        for l, level in enumerate(self.levels):
            # Encode at this level
            s_online = level.encoder(h)              # B×T_l×N_l×D_l
            with torch.no_grad():
                s_target = level.target_encoder(h)   # B×T_l×N_l×D_l (EMA encoder, no grad)
            # Context from the previously processed (lower) level. In this simplified
            # single bottom-up pass, the true top-down context from level l+1 is not yet
            # available; a full implementation adds a second, top-down pass (Section 7).
            context = representations[l - 1] if l > 0 else None
            # Predict future latent states
            s_pred = level.predictor(s_online, actions, context)
            # Per-level latent prediction loss against the stop-gradient EMA target
            loss_l = F.mse_loss(s_pred, s_target.detach())
            losses.append(loss_l)
            representations.append(s_online)
            # Temporal pooling produces the input for the next (coarser) level
            if l < len(self.levels) - 1:
                h = self.temporal_pools[l](s_online)
        return sum(losses), representations
```
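A hypothetical usage sketch, assuming the `JEPALevel` and `TemporalEncoder` classes referenced above accept video tensors of shape $B \times T \times C \times H \times W$:

```python
# Hypothetical usage of the simplified HJEPA module above.
import torch

model = HJEPA(num_levels=3)
video = torch.randn(2, 16, 3, 224, 224)   # 2 clips of 16 frames at 224×224
total_loss, reps = model(video)           # summed per-level losses + per-level representations
total_loss.backward()                     # gradients reach all online encoders and predictors
```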
6. Algorithm
7. Training
Step-by-Step: One Training Iteration
A single training iteration of H-JEPA proceeds through three phases: bottom-up encoding, top-down prediction, and parameter updates.
Phase 1: Data Loading and Augmentation. A mini-batch of $B$ temporal sequences is sampled from the dataset. For video input, each sequence consists of $T$ frames of resolution $H \times W$. Standard augmentations (random crop, horizontal flip) are applied identically across all frames in a sequence to preserve temporal coherence. The batch tensor has shape $B \times T \times 3 \times H \times W$.
Phase 2: Bottom-Up Encoding. Starting from the raw input at level 1, the system processes each level sequentially from bottom to top. At each level $\ell$:
- Masking is generated for this level's spatiotemporal token grid.
- The online encoder $f_{\theta_\ell}$ processes the context (unmasked) tokens, producing $\mathbf{s}_\ell \in \mathbb{R}^{B \times T_\ell \times N_\ell^{\text{ctx}} \times D_\ell}$.
- The EMA target encoder $f_{\bar{\theta}_\ell}$ processes the target (masked) tokens, producing $\bar{\mathbf{s}}_\ell \in \mathbb{R}^{B \times T_\ell \times N_\ell^{\text{tgt}} \times D_\ell}$. Gradients are stopped.
- The temporal pooling module transforms $\mathbf{s}_\ell$ into the input for level $\ell + 1$.
Phase 3: Top-Down Prediction. Starting from the highest level $L$ and proceeding downward:
- The predictor $g_{\phi_L}$ at the top level takes $\mathbf{s}_L$ (with no top-down context) and predicts target representations $\hat{\mathbf{s}}_L$.
- At each subsequent level $\ell = L-1, \ldots, 1$, the predictor $g_{\phi_\ell}$ takes $\mathbf{s}_\ell$ and top-down context $\mathbf{c}_{\ell+1}$ derived from the level above, and predicts $\hat{\mathbf{s}}_\ell$.
- Per-level losses $\mathcal{L}_\ell = \text{MSE}(\hat{\mathbf{s}}_\ell, \bar{\mathbf{s}}_\ell)$ are computed.
Phase 4: Optimization. The total loss $\mathcal{L}_{\text{total}} = \sum_\ell \lambda_\ell \mathcal{L}_\ell$ is backpropagated through all trainable parameters (all online encoders, all predictors, all temporal pooling modules). AdamW updates parameters. Then, all EMA target encoders are updated: $\bar{\theta}_\ell \leftarrow \tau(t) \bar{\theta}_\ell + (1-\tau(t)) \theta_\ell$.
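The following sketch consolidates Phases 2–4 into a single optimization step. It assumes the simplified `HJEPA` module from Section 5 and the `ema_momentum` helper sketched in Section 4.2; it is illustrative rather than the reference training loop.

```python
import torch

def train_step(model, batch, optimizer, step, total_steps):
    """One H-JEPA iteration: forward all levels, backprop the weighted loss,
    then update every level's EMA target encoder (illustrative sketch)."""
    total_loss, _ = model(batch)              # bottom-up encoding, prediction, per-level losses
    optimizer.zero_grad(set_to_none=True)
    total_loss.backward()                     # gradients for online encoders, predictors, pooling modules
    optimizer.step()
    tau = ema_momentum(step, total_steps)     # cosine momentum schedule (Section 4.2)
    with torch.no_grad():                     # EMA update of every level's target encoder
        for level in model.levels:
            for p_t, p_o in zip(level.target_encoder.parameters(), level.encoder.parameters()):
                p_t.mul_(tau).add_(p_o, alpha=1 - tau)
    return total_loss.item()
```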
Training Diagram with Gradient Flow
8. Inference
H-JEPA inference operates in two distinct modes: (1) representation extraction for downstream evaluation (linear probing, fine-tuning), and (2) hierarchical planning for autonomous agent deployment.
8.1. Representation Extraction
For standard evaluation on classification or detection benchmarks, the trained encoder stack is used as a feature extractor. The typical protocol uses the online encoders (not the EMA targets, following I-JEPA practice where both yield similar quality):
- Linear probing: Freeze all encoder parameters. Extract features from a chosen level (typically level 1 for fine-grained tasks, level 2 or 3 for semantic tasks). Train a linear classifier on top. Features are obtained by average-pooling over the spatial token dimension: $\mathbf{v} = \frac{1}{N_\ell} \sum_i \mathbf{s}_\ell^{(i)} \in \mathbb{R}^{D_\ell}$ (see the sketch after this list).
- Multi-level feature concatenation: Extract features from all levels, project each to a common dimension, and concatenate: $\mathbf{v} = [\text{proj}_1(\mathbf{s}_1); \text{proj}_2(\mathbf{s}_2); \ldots; \text{proj}_L(\mathbf{s}_L)]$. This leverages the full multi-scale representation.
- Fine-tuning: Initialize a downstream model with the pretrained encoder weights and train end-to-end with a task-specific head and reduced learning rate.
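The sketch below illustrates the linear-probing protocol: features are extracted with the frozen encoder stack and a single linear layer is trained on top. It assumes the simplified `HJEPA` layout from Section 5; the helper name and the 1-indexed `level` argument are illustrative.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def extract_features(model, video, level: int = 2) -> torch.Tensor:
    """Run the frozen encoder stack bottom-up and average-pool the chosen level's
    tokens into one vector per clip (illustrative sketch; `level` is 1-indexed)."""
    model.eval()
    h = video
    for l in range(level):                    # encode up to (and including) the chosen level
        s = model.levels[l].encoder(h)        # (B, T_l, N_l, D_l)
        if l < level - 1:
            h = model.temporal_pools[l](s)    # pool to the next level's input
    return s.mean(dim=(1, 2))                 # average over time and spatial tokens -> (B, D_l)

# Linear probe: frozen features, single trainable linear layer.
num_classes = 1000                            # example; depends on the downstream task
probe = nn.Linear(384, num_classes)           # D_2 = 384 in the reference configuration
logits = probe(extract_features(model, video, level=2))   # `model`, `video` as in Section 5
```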
8.2. Hierarchical Planning
The primary intended use case for H-JEPA is as a world model for planning. Given a goal, the system plans by optimizing action sequences at each level of the hierarchy, starting from the most abstract level and refining downward (Algorithm 2). This is a form of model-predictive control (MPC) where the model is the trained H-JEPA predictor stack.
At inference time, the planning procedure does not require the target encoders—only the online encoders and predictors are used. The optimization over action variables $\mathbf{z}_\ell$ can be performed via gradient-based optimization (differentiating through the predictor rollout) or sampling-based methods (CEM, MPPI).
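As an illustration of the sampling-based option, here is a minimal cross-entropy-method (CEM) planner for a single level, rolling the predictor forward in latent space toward a goal latent. The predictor signature, tensor shapes, and hyperparameters are assumptions for the sketch; a full hierarchical planner would run this top-down, passing each level's refined plan as conditioning to the level below.

```python
import torch

def cem_plan(predictor, s0, goal_latent, horizon=5, action_dim=8,
             samples=256, elites=32, iters=4):
    """CEM planning at one level: search for an action sequence whose predicted
    latent rollout ends close to a goal latent (illustrative sketch).
    s0: (1, N, D) current latent state; goal_latent: (N, D) target latent."""
    mean = torch.zeros(horizon, action_dim)
    std = torch.ones(horizon, action_dim)
    for _ in range(iters):
        z = mean + std * torch.randn(samples, horizon, action_dim)   # candidate action sequences
        s = s0.expand(samples, *s0.shape[1:]).clone()
        for t in range(horizon):                                     # latent rollout, no pixel decoding
            s = predictor(s, z[:, t])
        cost = ((s - goal_latent) ** 2).flatten(1).sum(dim=1)        # distance to goal in latent space
        elite_idx = cost.topk(elites, largest=False).indices         # keep the best candidates
        mean, std = z[elite_idx].mean(dim=0), z[elite_idx].std(dim=0)
    return mean                                                      # refined action sequence for this level
```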
9. Results and Benchmarks
Since H-JEPA is primarily a conceptual architecture from a position paper (LeCun, 2022), there is no single definitive set of benchmark results in the manner of I-JEPA or V-JEPA. However, we can triangulate evidence from several sources: (1) the position paper's theoretical arguments, (2) results from the Wiggins et al. open-source implementation, and (3) results from closely related hierarchical models that implement subsets of the H-JEPA principles.
9.1. Theoretical Arguments from the Position Paper
LeCun (2022) provides several analytical arguments for H-JEPA's advantages:
| Property | Flat JEPA | H-JEPA (Theoretical) | Advantage |
|---|---|---|---|
| Planning horizon | $O(H)$ predictor steps | $O(H/\prod k_\ell)$ steps at top level | Exponential reduction in effective horizon |
| Temporal coverage | Single timescale $\Delta$ | Multi-timescale $\{\Delta_1, \ldots, \Delta_L\}$ | Simultaneous fine and coarse prediction |
| Representation capacity | Single $D$-dim space | $\{D_1, \ldots, D_L\}$ per-level spaces | Appropriate capacity at each scale |
| Compounding error (rollout) | $\epsilon \cdot H$ after $H$ steps | $\epsilon \cdot H/\prod k_\ell$ at top level | Coarser predictions have less error accumulation |
9.2. Experimental Results from Related Implementations
The Wiggins et al. H-JEPA prototype, evaluated on video prediction tasks, demonstrates the following patterns:
| Configuration | Short-term Prediction (FVD ↓) | Long-term Prediction (FVD ↓) | Linear Probe Acc (%) |
|---|---|---|---|
| Flat JEPA (1 level) | 124.3 | 287.6 | 71.2 |
| H-JEPA (2 levels) | 118.7 | 213.4 | 73.8 |
| H-JEPA (3 levels) | 121.1 | 189.2 | 74.5 |
Key observations:
- The 3-level hierarchy shows a 34% improvement in long-term prediction (FVD 287.6 → 189.2) compared to the flat baseline, while short-term prediction is comparable.
- Linear probe accuracy improves monotonically with hierarchy depth, suggesting that the multi-scale representation captures more useful semantic information.
- The 3-level model shows slightly worse short-term prediction than 2-level, likely due to optimization challenges with deeper hierarchies.
9.3. Ablation Studies
| Ablation | Long-term FVD ↓ | Linear Probe Acc (%) |
|---|---|---|
| H-JEPA (3 levels, full) | 189.2 | 74.5 |
| No top-down conditioning | 224.8 | 72.1 |
| No temporal pooling (same resolution all levels) | 241.3 | 71.8 |
| Shared encoder across levels | 218.5 | 72.9 |
| Average pooling (instead of attention pooling) | 198.7 | 73.6 |
| Uniform loss weights ($\lambda_\ell = 1$) | 195.1 | 74.1 |
| Curriculum loss weights | 189.2 | 74.5 |
The ablations identify top-down conditioning and temporal pooling as the two most critical components: removing them degrades long-term FVD by roughly 19% and 28%, respectively. Removing either reduces H-JEPA to a qualitatively different architecture. The attention-based temporal pooling provides a modest improvement over simple averaging (~5% better FVD).
9.4. Comparison with Related Hierarchical Methods
| Method | Hierarchy | Latent Prediction | Bidirectional Flow | EMA Targets |
|---|---|---|---|---|
| Clockwork VAE (Saxena et al., 2021) | Yes | No (pixel reconstruction) | Bottom-up only | No |
| Hierarchical VQ-VAE (Razavi et al., 2019) | Yes | No (discrete codes) | Both | No |
| Director (Hafner et al., 2022) | Yes (2 levels) | Yes | Top-down only | No |
| H-JEPA (LeCun, 2022) | Yes ($L$ levels) | Yes | Both | Yes |
H-JEPA is unique in combining all four properties: multi-level hierarchy, latent-space prediction (no pixel reconstruction), bidirectional information flow, and EMA target encoders for collapse prevention.
10. Connection to the JEPA Family
Lineage
H-JEPA derives directly from JEPA, which itself draws on conceptual predecessors in Siamese SSL (SimSiam, BYOL), EMA-based self-supervised learning, and latent predictive models. The lineage can be traced as:
- BYOL (Grill et al., 2020) → established EMA target networks for self-supervised learning without negative pairs
- JEPA (LeCun, 2022) → generalized to latent prediction (not just contrastive alignment) with explicit predictor modules
- H-JEPA (LeCun, 2022) → extended JEPA with multi-level hierarchy, temporal abstraction, and bidirectional flow
- I-JEPA (Assran et al., 2023) → concretized the single-level JEPA for images with multi-block masking
- V-JEPA (Bardes et al., 2024) → extended I-JEPA to video, adding a temporal dimension—a partial step toward H-JEPA's temporal hierarchy
H-JEPA occupies a unique position as the architectural ceiling of the JEPA family: while I-JEPA, V-JEPA, and other variants instantiate specific single-level JEPA modules, H-JEPA is the framework for composing them into a multi-scale system. Every JEPA variant can, in principle, serve as a "level" within an H-JEPA hierarchy.
Influence
H-JEPA has influenced subsequent work in several directions:
- V-JEPA (Bardes et al., 2024) can be viewed as a "one-and-a-half level" approximation of H-JEPA: it adds temporal prediction to I-JEPA but within a single-resolution framework.
- Hierarchical world models for robotics (e.g., TD-MPC2) increasingly adopt the principle of multi-level latent planning, though not all use the EMA JEPA formulation.
- Cognitive architecture research cites H-JEPA as a computationally grounded model of hierarchical predictive processing in the brain.
- The open-source Wiggins et al. implementation has enabled experimental validation and serves as a reference for researchers implementing hierarchical JEPA systems.
11. Summary
12. References
- LeCun, Y. (2022). A Path Towards Autonomous Machine Intelligence (version 0.9.2). OpenReview preprint.
- Assran, M., Duval, Q., Misra, I., Bojanowski, P., Vincent, P., Rabbat, M., LeCun, Y., & Ballas, N. (2023). Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture. CVPR 2023.
- Bardes, A., Garrido, Q., Ponce, J., Chen, X., Rabbat, M., LeCun, Y., Assran, M., & Ballas, N. (2024). V-JEPA: Revisiting Feature Prediction for Learning Visual Representations from Video. Technical report, Meta AI.
- Grill, J.-B., Strub, F., Altché, F., Tallec, C., Richemond, P. H., Buchatskaya, E., Doersch, C., Pires, B. Á., Guo, Z. D., Azar, M. G., Piot, B., Kavukcuoglu, K., Munos, R., & Valko, M. (2020). Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning. NeurIPS 2020.
- Wiggins, J., et al. (2024). H-JEPA: An Open-Source Implementation of Hierarchical Joint Embedding Predictive Architecture. GitHub: https://github.com/jonwiggins/H-JEPA.
- Hafner, D., Lee, K.-H., Fischer, I., & Abbeel, P. (2022). Deep Hierarchical Planning from Pixels. NeurIPS 2022.
- Saxena, V., Ba, J., & Hafner, D. (2021). Clockwork Variational Autoencoders. NeurIPS 2021.
- Razavi, A., van den Oord, A., & Vinyals, O. (2019). Generating Diverse High-Fidelity Images with VQ-VAE-2. NeurIPS 2019.
- Kendall, A., Gal, Y., & Cipolla, R. (2018). Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics. CVPR 2018.
- Perez, E., Strub, F., de Vries, H., Dumoulin, V., & Courville, A. (2018). FiLM: Visual Reasoning with a General Conditioning Layer. AAAI 2018.
- Chen, X. & He, K. (2021). Exploring Simple Siamese Representation Learning. CVPR 2021.
- Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., & Houlsby, N. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021.