A-JEPA: Joint-Embedding Predictive Architecture Can Listen
1. Introduction
Self-supervised learning (SSL) in the audio domain has historically followed one of two paradigms: contrastive methods that learn by distinguishing positive pairs from negatives in a joint embedding space, and generative methods that learn by reconstructing corrupted input signals. Contrastive approaches such as COLA and CLAP require careful negative sampling and data augmentation engineering. Generative approaches such as Audio-MAE reconstruct masked spectrogram patches at the pixel level, which forces the model to expend capacity on low-level acoustic details—background noise, recording artifacts, microphone characteristics—that carry little semantic value for downstream tasks like environmental sound classification or speech understanding.
A-JEPA (Audio Joint-Embedding Predictive Architecture), introduced by Fei, Fan, and Huang in November 2023, transfers the Joint-Embedding Predictive Architecture framework from the visual domain to audio. The core insight is deceptively simple: rather than predicting raw spectrogram pixels, predict the learned representations of masked time-frequency regions. By operating entirely in a latent embedding space, A-JEPA can filter out irrelevant acoustic noise and focus on the semantic content that matters for understanding what is happening in an audio scene.
A-JEPA derives directly from I-JEPA (Assran et al., 2023), which demonstrated this principle for natural images. However, audio spectrograms differ from natural images in fundamental ways: they exhibit strong temporal continuity along the time axis, harmonic structure along the frequency axis, and a heterogeneous information density where semantically meaningful events may be sparse and irregularly distributed across the time-frequency plane. A-JEPA adapts the I-JEPA framework to respect these properties through a tailored masking strategy that accounts for the asymmetric axes of the spectrogram, and through architectural choices that handle variable-length audio inputs converted to fixed-size spectrograms.
Key Contributions
- Domain transfer of JEPA to audio. A-JEPA is among the first works to demonstrate that joint-embedding predictive learning—predicting representations rather than reconstructing inputs—is effective for audio self-supervised learning, achieving competitive performance with state-of-the-art methods without requiring handcrafted augmentations or negative sampling.
- Time-frequency masking strategy. A multi-block masking approach tailored for spectrograms, producing contiguous rectangular masks in the time-frequency plane that respect the distinct semantics of each axis.
- Noise-filtering by design. By predicting in representation space rather than pixel space, A-JEPA inherently discards acoustically irrelevant information—recording noise, reverberation artifacts, codec distortions—that generative methods like Audio-MAE are forced to model.
- Competitive results without reconstruction. A-JEPA achieves results on par with or exceeding Audio-MAE and approaches BEATs-level performance on benchmarks such as AudioSet, ESC-50, and Speech Commands, despite using a simpler learning objective.
How A-JEPA Differs from I-JEPA
While A-JEPA inherits the overall JEPA architecture (context encoder, predictor, EMA target encoder), several adaptations are necessary for the audio domain:
- Input representation: Audio waveforms are first converted to mel-spectrograms, producing 2D time-frequency representations that serve as the input "image." Unlike natural images, these have semantically distinct axes—time (horizontal) and frequency (vertical)—and A-JEPA's masking must account for this asymmetry.
- Patch embedding: Spectrogram patches are typically rectangular (e.g., 16×16 on the mel-spectrogram grid), matching the resolution of standard audio feature extraction pipelines.
- Masking geometry: Where I-JEPA uses square-ish aspect-ratio ranges for its target blocks, A-JEPA adjusts the aspect ratio distributions to produce masks that are more elongated along the time axis, reflecting the temporal continuity of audio events.
- Input normalization: Spectrograms require different normalization (per-channel mean/std computed over mel bins) compared to ImageNet-style RGB normalization.
2. Method
Imagine a musician hearing a song with several seconds muted. An experienced listener does not need to reconstruct the exact waveform of the missing segment; instead, they form a high-level understanding—"a drum fill probably bridges these two phrases" or "the melody likely resolves to the tonic here." A-JEPA works the same way: given visible portions of a spectrogram, the model predicts what those missing regions mean rather than what they sound like. This distinction is crucial. Reconstructing raw spectrogram pixels means modeling microphone hiss, room reverb, and compression artifacts. Predicting semantic representations means focusing on what matters: the identity of sounds, their temporal relationships, and their spectral characteristics.
The method proceeds in four conceptual stages:
Stage 1: Audio to Spectrogram
Raw audio waveforms are converted into mel-spectrograms using a standard Short-Time Fourier Transform (STFT) followed by a mel-filterbank. This produces a 2D representation where the horizontal axis is time and the vertical axis is frequency (mel-scaled). Think of this as creating a "photograph" of the sound—a visual fingerprint where patterns correspond to acoustic events. A bird chirp appears as a series of bright, high-frequency arcs; a car engine rumble as a sustained low-frequency band.
Stage 2: Patch Tokenization and Masking
The spectrogram is divided into a grid of non-overlapping patches, exactly like a Vision Transformer tokenizes an image. Some patches are then selected as targets (what we want to predict) and the remaining patches form the context (what we can see). The masking strategy selects multiple contiguous rectangular blocks as targets—this is the multi-block masking inherited from I-JEPA. The key adaptation for audio is that these blocks are drawn with aspect ratios biased toward the time axis, creating horizontally elongated masks that correspond to temporal segments of sound events.
Imagine laying a crossword puzzle grid over a spectrogram. The blacked-out squares are your targets—contiguous rectangular regions where you cannot see the spectrogram content. Your job is to infer what acoustic event occupies each blacked-out region based on what you can see in the surrounding context. If a bird call spans several time steps and you can see its onset and offset, you should be able to predict the representation of the middle portion. The model learns to build exactly this kind of contextual reasoning.
Stage 3: Encode and Predict
A Vision Transformer encoder processes only the visible context patches, producing rich representations. A lightweight predictor network then takes these context representations, along with positional information about where each target patch sits in the spectrogram, and produces predicted representations for the masked regions. Critically, the predictor is deliberately made narrower than the main encoder—this bottleneck forces it to capture high-level semantic relationships rather than memorizing low-level patterns.
Stage 4: Compare in Representation Space
Meanwhile, a target encoder—a momentum-updated copy of the main encoder—processes the full, unmasked spectrogram to produce target representations for the masked regions. The training loss measures the distance between the predictor's output and these target representations. Because the target encoder is updated via exponential moving average (EMA) rather than gradient descent, it provides slowly-evolving, stable targets that prevent the kind of representational collapse where all patches map to the same vector.
Audio-MAE reconstructs the raw mel-spectrogram values of masked patches. This forces the encoder to dedicate capacity to modeling recording noise, microphone characteristics, and other acoustically irrelevant details. A-JEPA sidesteps this: raw spectrogram values enter the model only as input, never as a prediction target. The loss operates entirely in a learned representation space, so the model naturally learns to encode only what is semantically useful for predicting the meaning of nearby acoustic events.
3. Model Overview
At a Glance
| Component | Details |
|---|---|
| Input | Mel-spectrogram (128 mel bins × T time frames; 10 s audio at 16 kHz with a 10 ms hop yields ~1000 frames, padded to 1024 → 128 × 1024 before patching) |
| Patch Size | 16 × 16 on the spectrogram grid |
| Masking | Multi-block masking in time-frequency plane: 4 target blocks, ~60–75% of patches masked |
| Context Encoder | ViT-Base/16 (12 layers, 768 dim, 12 heads) — trainable via gradients |
| Target Encoder | ViT-Base/16 (identical architecture) — updated via EMA, no gradients |
| Predictor | Narrow Transformer (6 layers, 384 dim, 12 heads) — trainable |
| Loss | Smooth L1 (Huber) loss on patch-level representations in embedding space |
| Key Result | Competitive with Audio-MAE on AudioSet-20K (linear probe); strong on ESC-50 and Speech Commands |
| Parameters | ~86M (encoder) + ~86M (target encoder, shared architecture) + ~22M (predictor) |
Training Architecture Diagram
4. Main Components of A-JEPA
4.1 Context Encoder
What. The context encoder is a standard Vision Transformer (ViT-Base/16) that processes only the visible (unmasked) patches of the mel-spectrogram. It maps each visible patch to a 768-dimensional representation vector, producing a set of contextualized embeddings that capture the acoustic content and relationships among the visible regions.
How. The encoder follows the standard ViT-B architecture: 12 Transformer layers, 768 embedding dimension, 12 attention heads (64 dim per head), MLP ratio of 4 (intermediate dimension 3072). Input spectrogram patches of size 16×16 are linearly projected to 768 dimensions and summed with learned 2D positional embeddings that encode both time and frequency position on the spectrogram grid. Layer normalization is applied before each attention and MLP block (pre-norm convention). Only context (unmasked) patches are fed to the encoder, making forward pass cost proportional to the number of visible patches—typically 25–40% of the total, yielding significant computational savings during training compared to methods that process all patches.
Why. Using ViT-Base aligns with the standard architecture used in I-JEPA and enables direct comparison with other audio SSL methods (Audio-MAE also uses ViT-B). The choice to process only visible patches (rather than all patches with mask tokens as in MAE-style methods) is critical: it forces the encoder to build representations from incomplete information, and the separate predictor must then reason about what lies in the masked regions. This architectural separation between encoding and prediction is a defining feature of the JEPA framework. Ablations in I-JEPA demonstrated that this separation leads to higher-quality features compared to MAE-style approaches where the encoder sees mask tokens.
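The patch-embedding step described above can be sketched with the standard ViT trick of a strided convolution; class and variable names here are illustrative, not from the paper.

```python
import torch
import torch.nn as nn

class SpectrogramPatchEmbed(nn.Module):
    """Split a (B, 1, F, T) mel-spectrogram into 16x16 patches and
    project each to the encoder dimension (768 for ViT-B)."""

    def __init__(self, patch_size=16, embed_dim=768, grid=(8, 64)):
        super().__init__()
        # A conv with kernel == stride == patch size is equivalent to
        # flattening each patch and applying a shared linear projection.
        self.proj = nn.Conv2d(1, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        # Learned 2D positional embedding, one vector per grid cell.
        self.pos_embed = nn.Parameter(torch.zeros(1, grid[0] * grid[1], embed_dim))

    def forward(self, spec):
        x = self.proj(spec)               # (B, 768, F/16, T/16)
        x = x.flatten(2).transpose(1, 2)  # (B, N, 768)
        return x + self.pos_embed

patches = SpectrogramPatchEmbed()(torch.randn(2, 1, 128, 1024))
print(patches.shape)  # torch.Size([2, 512, 768])
```

For a 128 × 1024 spectrogram this yields the 8 × 64 = 512-token grid used throughout the paper.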
4.2 Target Encoder (EMA)
What. The target encoder has the same architecture as the context encoder (ViT-B/16) but processes the complete, unmasked spectrogram. Its output representations at the target (masked) patch positions serve as prediction targets. It is not trained by gradient descent; instead, its weights are an exponential moving average (EMA) of the context encoder's weights.
How. After each training step, the target encoder parameters $\bar{\theta}$ are updated as:
$$\bar{\theta} \leftarrow \tau \cdot \bar{\theta} + (1 - \tau) \cdot \theta$$where $\theta$ are the context encoder parameters and $\tau$ is the EMA momentum coefficient. Following I-JEPA conventions, $\tau$ follows a cosine schedule from an initial value (e.g., 0.996) to a final value (e.g., 0.9999) over the course of training, starting with faster updates and progressively slowing to provide increasingly stable targets. The target encoder processes all $N$ patches (not just context patches), and the representations at the masked positions are extracted after the forward pass.
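A minimal sketch of this update rule and the cosine momentum schedule (endpoint values from the text; function names are ours):

```python
import math
import torch

def ema_momentum(step, total_steps, tau_start=0.996, tau_end=0.9999):
    """Cosine ramp of the EMA coefficient from tau_start up to tau_end."""
    progress = step / max(1, total_steps)
    return tau_end - (tau_end - tau_start) * 0.5 * (1 + math.cos(math.pi * progress))

@torch.no_grad()
def ema_update(target, online, tau):
    """theta_bar <- tau * theta_bar + (1 - tau) * theta, parameter-wise."""
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.mul_(tau).add_(p_o, alpha=1 - tau)

# Toy usage with two small modules standing in for the two encoders.
online = torch.nn.Linear(4, 4)
target = torch.nn.Linear(4, 4)
t0 = target.weight.detach().clone()
tau = ema_momentum(step=0, total_steps=100)  # 0.996 at the start of training
ema_update(target, online, tau)
# target.weight is now tau * t0 + (1 - tau) * online.weight
```

Because the update runs under `torch.no_grad()`, no gradients ever flow into the target encoder, matching the stop-gradient described below.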
Why. The EMA target encoder serves two purposes. First, it provides stable prediction targets: because the target encoder's weights change slowly relative to the context encoder, the prediction targets do not shift rapidly between iterations, which facilitates learning. Second, it helps prevent representational collapse—the failure mode where all inputs map to the same representation. By decoupling the target-producing network from gradient flow (stop-gradient on the target side), the system avoids the trivial solution where encoder and predictor conspire to produce constant outputs. It is worth noting that EMA alone does not guarantee collapse prevention; it works in concert with the predictor bottleneck and the multi-block masking strategy to encourage informative representations. The precise theoretical conditions under which this combination provably avoids collapse remain an open research question.
4.3 Predictor
What. The predictor is a smaller Transformer network that takes the context encoder's output representations for visible patches and produces predicted representations for the masked (target) patch positions. It bridges the gap between what the model can see and what it needs to predict.
How. The predictor is architecturally narrower than the main encoder: 6 Transformer layers with a hidden dimension of 384 (half of the encoder's 768) and 12 attention heads (32 dim per head). The input to the predictor consists of the context encoder's output tokens (for visible patches) concatenated with learnable mask tokens (for target positions). Both sets of tokens are augmented with positional embeddings so the predictor knows the spatial location of each token on the spectrogram grid. The predictor's output at the mask token positions is then projected back to the encoder's 768-dimensional space via a linear layer before the loss is computed.
Why. The predictor bottleneck—reducing dimension from 768 to 384—is a deliberate design choice inherited from I-JEPA. A predictor that is too powerful (e.g., same capacity as the encoder) could learn to bypass the encoder entirely, predicting targets by memorizing input-output mappings without forcing the encoder to learn meaningful representations. Conversely, a predictor that is too weak would be unable to capture the cross-patch reasoning needed to predict masked regions. The 6-layer, half-dimension configuration represents an empirically validated sweet spot. Ablation studies in I-JEPA showed that predictor depth and width significantly affect downstream performance: too shallow and the model cannot capture complex spatial relationships; too deep and the encoder features degrade. A-JEPA adopts these findings for the audio domain.
4.4 Masking Strategy
What. A-JEPA uses a multi-block masking strategy that selects multiple contiguous rectangular regions of the spectrogram as prediction targets, leaving the remaining patches as context. This strategy is adapted from I-JEPA's multi-block masking but tailored for the time-frequency structure of spectrograms.
How. The masking procedure works as follows:
- Target block sampling: $M = 4$ target blocks are sampled. Each block is a contiguous rectangle on the spectrogram patch grid. The block scale (fraction of total patches) is sampled uniformly from $[0.15, 0.2]$, and the aspect ratio is sampled from a range biased toward the time axis (e.g., $[0.75, 1.5]$ for the time/frequency ratio), producing blocks that tend to be wider in time than in frequency.
- Context formation: All patches not covered by any target block form the context set. The overall masking ratio is typically 60–75% of patches, meaning the context encoder sees only 25–40% of the spectrogram.
- Collision handling: If sampled blocks overlap, the union of covered patches forms the target set.
Why. Contiguous block masking is essential for forcing semantic reasoning rather than low-level interpolation. If patches were masked randomly (as in standard MAE), a model could predict each masked patch by simply interpolating from its immediate spatial neighbors—a strategy that exploits local smoothness in the spectrogram without learning high-level acoustic concepts. By masking entire rectangular regions, A-JEPA ensures that many target patches have no adjacent visible patches, forcing the model to reason about the broader acoustic context: what sound event is occurring, how it evolves over time, and what its spectral signature should be at the masked positions. The time-axis bias in aspect ratios reflects the observation that audio events have stronger continuity along time than across frequency bands—masking a contiguous time segment of a sound forces the model to understand the temporal dynamics of that sound.
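The procedure above can be sketched as follows. The exact rounding and clamping rules are assumptions modeled on the I-JEPA codebase rather than taken verbatim from the A-JEPA paper:

```python
import torch

def sample_multiblock_mask(grid_f=8, grid_t=64, num_blocks=4,
                           scale=(0.15, 0.20), aspect=(0.75, 1.5)):
    """Return a (grid_f, grid_t) boolean mask, True = target (masked).

    Each block is a contiguous rectangle; overlapping blocks simply
    union together, per the collision-handling rule.
    """
    n_total = grid_f * grid_t
    mask = torch.zeros(grid_f, grid_t, dtype=torch.bool)
    for _ in range(num_blocks):
        s = torch.empty(1).uniform_(*scale).item()    # fraction of patches
        ar = torch.empty(1).uniform_(*aspect).item()  # time/frequency ratio
        area = s * n_total
        # ar > 1 makes the block wider along time than along frequency.
        h = max(1, min(grid_f, round((area / ar) ** 0.5)))
        w = max(1, min(grid_t, round((area * ar) ** 0.5)))
        top = torch.randint(0, grid_f - h + 1, (1,)).item()
        left = torch.randint(0, grid_t - w + 1, (1,)).item()
        mask[top:top + h, left:left + w] = True
    return mask

mask = sample_multiblock_mask()
context_idx = (~mask).flatten().nonzero(as_tuple=True)[0]  # visible patches
```

Note how clamping against the short frequency axis (8 patches) naturally elongates blocks along time on the 8 × 64 grid.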
4.5 Loss Function
What. A-JEPA uses a smooth L1 (Huber) loss computed between the predicted representations and the target representations at masked patch positions. The loss operates entirely in the learned representation space—never on raw spectrogram values.
How. Let $x \in \mathbb{R}^{F \times T}$ denote a mel-spectrogram with $F$ frequency bins and $T$ time frames. After patchification into $N$ total patches and masking, we have a context set $\mathcal{C} \subset \{1, \ldots, N\}$ and target set $\mathcal{T} \subset \{1, \ldots, N\}$ with $\mathcal{C} \cap \mathcal{T} = \emptyset$.
The context encoder $f_\theta$ produces representations for visible patches:
$$\{h_i\}_{i \in \mathcal{C}} = f_\theta(\{p_i\}_{i \in \mathcal{C}})$$where $p_i \in \mathbb{R}^{P^2}$ is the flattened patch at position $i$ and $h_i \in \mathbb{R}^D$ is its representation ($D = 768$ for ViT-B).
The target encoder $f_{\bar{\theta}}$ (with EMA parameters $\bar{\theta}$) processes all patches and we extract target representations:
$$\{z_j\}_{j \in \mathcal{T}} = \text{sg}\left[f_{\bar{\theta}}(\{p_i\}_{i=1}^{N})\right]_{\mathcal{T}}$$where $\text{sg}[\cdot]$ denotes stop-gradient (no gradients flow through the target encoder) and the subscript $\mathcal{T}$ indicates extraction of representations at target positions.
The predictor $g_\phi$ takes context representations and mask tokens with positional embeddings, and produces predictions for target positions:
$$\{\hat{z}_j\}_{j \in \mathcal{T}} = g_\phi\left(\{h_i\}_{i \in \mathcal{C}}, \{m_j + e_j\}_{j \in \mathcal{T}}\right)$$where $m_j \in \mathbb{R}^{D_{\text{pred}}}$ is a learnable mask token and $e_j \in \mathbb{R}^{D_{\text{pred}}}$ is the positional embedding for position $j$. The predictor output is projected back to dimension $D$ via a linear layer.
The smooth L1 loss is then:
$$\mathcal{L} = \frac{1}{|\mathcal{T}|} \sum_{j \in \mathcal{T}} \text{SmoothL1}(\hat{z}_j, z_j)$$where the smooth L1 (Huber) loss for each patch is defined as:
$$\text{SmoothL1}(\hat{z}, z) = \frac{1}{D}\sum_{d=1}^{D} \begin{cases} \frac{1}{2}(\hat{z}_d - z_d)^2 / \beta & \text{if } |\hat{z}_d - z_d| < \beta \\ |\hat{z}_d - z_d| - \frac{\beta}{2} & \text{otherwise} \end{cases}$$with $\beta = 1.0$ as the transition threshold between L2 and L1 behavior. The representations $z_j$ are typically normalized (layer-normed) before the loss computation to stabilize training.
Variables summary:
| Symbol | Meaning | Typical Value/Dimension |
|---|---|---|
| $x$ | Input mel-spectrogram | $\mathbb{R}^{128 \times T}$ |
| $N$ | Total number of patches | e.g., 512 for 128×1024 with 16×16 patches (8×64 grid) |
| $\mathcal{C}$ | Context (visible) patch indices | $|\mathcal{C}| \approx 0.25N$ to $0.4N$ |
| $\mathcal{T}$ | Target (masked) patch indices | $|\mathcal{T}| \approx 0.6N$ to $0.75N$ |
| $f_\theta$ | Context encoder (trainable) | ViT-B, ~86M params |
| $f_{\bar{\theta}}$ | Target encoder (EMA) | Same architecture, EMA-updated |
| $g_\phi$ | Predictor (trainable) | 6-layer Transformer, 384 dim |
| $D$ | Encoder embedding dimension | 768 |
| $D_{\text{pred}}$ | Predictor embedding dimension | 384 |
| $\tau$ | EMA momentum | Cosine schedule 0.996 → 0.9999 |
| $\beta$ | Smooth L1 threshold | 1.0 |
Why. Smooth L1 loss combines the advantages of L1 and L2: it behaves like L2 for small errors (providing smooth gradients near zero) and like L1 for large errors (reducing sensitivity to outliers). In the context of representation prediction, this is desirable because occasional large prediction errors—which may arise from ambiguous or noisy regions of the spectrogram—should not dominate the gradient signal. The choice to normalize target representations before computing the loss prevents the trivial collapse mode where the encoder learns to output zero vectors (which would minimize any distance-based loss).
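As a sanity check, the piecewise definition above can be implemented directly and compared against PyTorch's built-in smooth_l1_loss with beta = 1.0:

```python
import torch
import torch.nn.functional as F

def smooth_l1(pred, target, beta=1.0):
    """Piecewise Huber loss: quadratic below beta, linear above."""
    diff = (pred - target).abs()
    loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    return loss.mean()

pred = torch.tensor([0.2, 1.8, -0.5])
target = torch.tensor([0.0, 0.0, 0.0])
ours = smooth_l1(pred, target)
ref = F.smooth_l1_loss(pred, target, beta=1.0)
assert torch.allclose(ours, ref)
```

The middle element (error 1.8 > beta) falls in the linear regime, so its contribution is 1.8 − 0.5 rather than 1.62, illustrating the reduced outlier sensitivity discussed above.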
4.6 Audio-Specific Design: Spectrogram Preprocessing
What. Before entering the JEPA pipeline, raw audio undergoes a standardized preprocessing chain that converts waveforms into mel-spectrograms suitable for patch-based processing.
How. The preprocessing pipeline consists of:
- Resampling: Audio is resampled to 16 kHz.
- STFT: A Short-Time Fourier Transform is applied with a window size of 25 ms (400 samples), hop length of 10 ms (160 samples), and an FFT size of 512.
- Mel filterbank: 128 mel-spaced triangular filters are applied to the power spectrum, producing 128 frequency bins.
- Log compression: The mel power spectrogram is converted to log scale: $\log(\text{mel} + \epsilon)$ with $\epsilon = 10^{-6}$.
- Normalization: Per-channel (per-frequency-bin) mean and standard deviation normalization using statistics computed over the training set.
- Fixed-length padding/cropping: Audio clips are padded or cropped to a fixed duration (e.g., 10 seconds), producing spectrograms of fixed temporal extent.
Why. The mel-spectrogram representation is standard across audio SSL methods and provides a good balance between temporal and spectral resolution. The 128-mel-bin × T-frame representation maps naturally onto the 2D patch grid used by ViT, enabling direct application of image-domain Transformer architectures. Log compression reduces the dynamic range of the spectrogram, making it more amenable to neural network processing. Per-channel normalization accounts for the uneven energy distribution across frequency bands in natural audio (low frequencies typically have much higher energy than high frequencies).
5. Implementation Details
| Hyperparameter | Value |
|---|---|
| Encoder Architecture | ViT-Base/16 |
| Encoder Layers | 12 |
| Encoder Heads | 12 |
| Encoder Dimension | 768 |
| Encoder MLP Dim | 3072 (4× expansion) |
| Predictor Architecture | Narrow Transformer |
| Predictor Layers | 6 |
| Predictor Heads | 12 |
| Predictor Dimension | 384 |
| Input | 128 mel bins × variable time frames |
| Patch Size | 16 × 16 |
| Audio Duration | 10 s (AudioSet standard) |
| Sample Rate | 16 kHz |
| Masking | Multi-block (4 target blocks) |
| Target Block Scale | U[0.15, 0.2] |
| Masking Ratio | ~60–75% |
| Optimizer | AdamW ($\beta_1 = 0.9$, $\beta_2 = 0.95$) |
| Base Learning Rate | 1.5e-4 |
| LR Schedule | Cosine decay with linear warmup |
| Warmup Epochs | 40 |
| Weight Decay | 0.05 |
| Batch Size | 256 (effective, across GPUs) |
| Total Epochs | 300–400 |
| EMA Schedule | Cosine: 0.996 → 0.9999 |
| Pre-training Data | AudioSet-2M (unbalanced) |
| GPUs | 8× A100 (40 GB) or equivalent |
| Mixed Precision | FP16 / BF16 |
| Loss | Smooth L1 (β = 1.0) |
| Target Normalization | Layer norm on target representations |
Note: A-JEPA does not have a public code repository. The values above are reported in the paper or inferred from alignment with the I-JEPA codebase, which A-JEPA explicitly builds upon. Where the paper does not specify a value, we note the I-JEPA default with the caveat that the audio adaptation may differ.
6. Algorithm
Reference Implementation Sketch
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
class AJEPA(nn.Module):
"""A-JEPA: Audio Joint-Embedding Predictive Architecture."""
def __init__(
self,
encoder: nn.Module, # ViT-B/16
predictor: nn.Module, # Narrow Transformer (6L, 384d)
ema_momentum_schedule: list, # cosine schedule values
):
super().__init__()
self.context_encoder = encoder
self.predictor = predictor
# Target encoder: same architecture, no gradient
self.target_encoder = copy.deepcopy(encoder)
for p in self.target_encoder.parameters():
p.requires_grad = False
self.ema_schedule = ema_momentum_schedule
self.step = 0
def forward(self, spectrograms, context_masks, target_masks):
"""
Args:
spectrograms: (B, 1, F, T) mel-spectrograms
context_masks: (B, N_ctx) indices of visible patches
target_masks: (B, N_tgt) indices of masked patches
Returns:
loss: scalar smooth-L1 loss
"""
# Context encoder: process only visible patches
context_patches = self.extract_patches(spectrograms, context_masks)
context_reps = self.context_encoder(context_patches, context_masks)
# B × N_ctx × 768
# Target encoder: process all patches, extract targets
with torch.no_grad():
all_reps = self.target_encoder(spectrograms) # B × N × 768
target_reps = self.gather(all_reps, target_masks) # B × N_tgt × 768
target_reps = F.layer_norm(target_reps, [target_reps.shape[-1]])
# Predictor: predict target representations
pred_reps = self.predictor(context_reps, context_masks, target_masks)
# B × N_tgt × 768
# Smooth L1 loss
loss = F.smooth_l1_loss(pred_reps, target_reps)
return loss
@torch.no_grad()
def ema_update(self):
tau = self.ema_schedule[self.step]
for p_target, p_online in zip(
self.target_encoder.parameters(),
self.context_encoder.parameters()
):
p_target.data.mul_(tau).add_(p_online.data, alpha=1 - tau)
self.step += 1
def train_step(model, optimizer, batch):
spectrograms = batch["spectrogram"] # (B, 1, 128, T)
context_masks, target_masks = sample_multiblock_masks(
grid_size=(8, 64), # 128/16 freq patches × 1024/16 time patches
num_blocks=4,
scale_range=(0.15, 0.2),
aspect_ratio_range=(0.75, 1.5),
)
loss = model(spectrograms, context_masks, target_masks)
optimizer.zero_grad()
loss.backward()
optimizer.step()
model.ema_update()
return loss.item()
7. Training
Step-by-Step: One Training Iteration
- Data loading. A mini-batch of $B = 256$ audio clips is loaded from AudioSet-2M. Each clip is a 10-second waveform at 16 kHz.
- Spectrogram conversion. Each waveform is converted to a mel-spectrogram via STFT (window 25 ms, hop 10 ms, 512 FFT bins, 128 mel filters), producing tensors of shape $(B, 1, 128, 1024)$. Log compression and per-channel normalization are applied.
- Patchification. Each spectrogram is divided into a grid of $8 \times 64 = 512$ non-overlapping $16 \times 16$ patches. Each patch is linearly projected to dimension 768 and augmented with a learned 2D positional embedding: $p_i \leftarrow W_{\text{proj}} \cdot \text{flatten}(\text{patch}_i) + e_i$.
- Mask sampling. For each sample in the batch, the multi-block masking procedure (Section 4.4) produces 4 target blocks, yielding a target set $\mathcal{T}$ (~307–384 patches, 60–75% of 512) and a context set $\mathcal{C}$ (the remaining ~128–205 patches). Masks are sampled independently per sample.
- Context encoding. The context encoder $f_\theta$ (ViT-B/16, trainable) processes only the context patches for each sample. Self-attention is computed among context tokens only. Output: $(B, |\mathcal{C}|, 768)$.
- Target encoding. Under torch.no_grad(), the target encoder $f_{\bar{\theta}}$ processes the full set of 512 patches per sample. Target representations are extracted at the masked positions and layer-normalized. Output: $(B, |\mathcal{T}|, 768)$.
- Prediction. The predictor $g_\phi$ receives the context encoder's output tokens concatenated with learnable mask tokens (one per target position, augmented with positional embeddings). The predictor's 6-layer Transformer operates at 384 dimensions internally and produces outputs for each target position. A final linear layer projects the predictions back to 768 dimensions. Output: $(B, |\mathcal{T}|, 768)$.
- Loss computation. Smooth L1 loss is computed between predicted and target representations, averaged over target patches and batch samples. Typical loss values start around 0.5–1.0 and decrease to 0.05–0.15 over training.
- Backward pass. Gradients are computed with respect to $\theta$ (context encoder) and $\phi$ (predictor). No gradients flow to the target encoder.
- Parameter update. AdamW optimizer updates $\theta$ and $\phi$ using the current learning rate from the cosine schedule (peak LR 1.5e-4, with 40-epoch warmup).
- EMA update. Target encoder weights are updated: $\bar{\theta} \leftarrow \tau(t) \cdot \bar{\theta} + (1 - \tau(t)) \cdot \theta$, where $\tau(t)$ follows a cosine schedule from 0.996 to 0.9999.
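The learning-rate schedule used in the parameter-update step (linear warmup, then cosine decay) can be written as a small helper; base_lr and the warmup length follow the table in Section 5, while the function itself is our sketch:

```python
import math

def lr_schedule(step, total_steps, warmup_steps,
                base_lr=1.5e-4, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay toward min_lr."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Per-step usage: assign the value to each optimizer param group.
# for g in optimizer.param_groups:
#     g["lr"] = lr_schedule(step, total_steps, warmup_steps)
```

The same cosine shape, with different endpoints, drives the EMA momentum $\tau(t)$ in the final step.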
Training Diagram
Training Dynamics
Several aspects of A-JEPA training merit attention:
- Computational efficiency. Because the context encoder processes only 25–40% of patches, the forward pass cost is significantly reduced compared to methods that process all patches (like BEATs with teacher distillation). The predictor, being narrower (384 vs. 768 dim), adds relatively little overhead. Overall training cost is comparable to or lower than Audio-MAE.
- Masking ratio sensitivity. The 60–75% masking ratio is significantly higher than typical NLP masking (15% in BERT) but consistent with vision JEPA methods. High masking ratios force the model to reason about large missing regions rather than relying on local interpolation.
- EMA schedule. The cosine schedule for $\tau$ means early training has relatively fast target encoder updates (smaller $\tau$), allowing the target representations to improve quickly, while late training has very slow updates (large $\tau$), providing stable targets as the model converges.
- No data augmentation required. Unlike contrastive methods (e.g., COLA), A-JEPA does not rely on audio augmentations (pitch shifting, time stretching, noise injection) for creating positive pairs. The masking itself provides the pretext task, simplifying the training pipeline.
8. Inference
At inference time, the JEPA training apparatus (predictor, masking, loss) is discarded. Only an encoder is retained as a feature extractor: either the context encoder $f_\theta$ or the EMA target encoder $f_{\bar{\theta}}$, which typically performs slightly better.
Inference Pipeline
Downstream Evaluation Protocols
Linear probing. The pre-trained encoder is frozen. A global average pooling operation aggregates the $N$ patch-level representations into a single 768-dimensional vector per audio clip. A single linear layer is trained on top with cross-entropy loss to classify into task-specific categories. This protocol directly measures the quality of the learned representations: better pre-training produces representations that are more linearly separable for downstream tasks.
Full fine-tuning. The pre-trained encoder weights initialize the model, and a classification head (typically a single linear layer or a small MLP) is appended. All parameters are updated end-to-end on the downstream task, typically with a lower learning rate for the encoder (e.g., 1e-5) and a higher learning rate for the classification head (e.g., 1e-3). This protocol achieves the best absolute performance but is a less pure measure of representation quality since the encoder can adapt to the task.
Feature extraction. For each input audio, the representation pipeline produces:
- Patch-level features: $(N, 768)$ — one 768-dim vector per spectrogram patch, useful for tasks requiring temporal or spectral localization.
- Clip-level feature: $(768,)$ — obtained via global average pooling, suitable for clip-level classification.
- Multi-layer features: Representations from multiple encoder layers can be concatenated or weighted-averaged for richer features, following the approach used in HuBERT and other audio SSL methods.
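The linear-probing protocol reduces to a frozen encoder, mean pooling over patch tokens, and a single trainable linear layer. A minimal sketch (num_classes = 50 is illustrative, matching ESC-50; the stand-in encoder is an identity over pre-embedded tokens):

```python
import torch
import torch.nn as nn

class LinearProbe(nn.Module):
    """Frozen patch-level features -> pooled clip embedding -> linear head."""

    def __init__(self, encoder, embed_dim=768, num_classes=50):
        super().__init__()
        self.encoder = encoder
        for p in self.encoder.parameters():
            p.requires_grad = False          # encoder stays frozen
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, patch_tokens):
        with torch.no_grad():
            feats = self.encoder(patch_tokens)  # (B, N, 768) patch features
        clip = feats.mean(dim=1)                # global average pool -> (B, 768)
        return self.head(clip)                  # (B, num_classes) logits

# Stand-in encoder: identity over already-embedded patch tokens.
probe = LinearProbe(nn.Identity())
logits = probe(torch.randn(4, 512, 768))
print(logits.shape)  # torch.Size([4, 50])
```

Only `self.head` receives gradients, so probe accuracy directly reflects how linearly separable the frozen representations are.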
9. Results & Benchmarks
Main Results
A-JEPA is evaluated on standard audio understanding benchmarks. The following table compares A-JEPA with other self-supervised and supervised audio methods. All methods use ViT-B/16 or equivalent-capacity architectures for fair comparison.
| Method | Type | Pre-train Data | AudioSet-20K (mAP) | ESC-50 (Acc.) | Speech Cmds V2 (Acc.) |
|---|---|---|---|---|---|
| Supervised ViT-B | Supervised | AudioSet-2M | 35.0 | 83.9 | 96.8 |
| SSAST | Generative SSL | AudioSet-2M + LibriSpeech | 31.0 | 88.8 | 98.0 |
| Audio-MAE | Generative SSL | AudioSet-2M | 37.0 | 90.0 | 98.0 |
| MaskSpec | Generative SSL | AudioSet-2M | 32.3 | 89.6 | – |
| BEATs (iter3) | Contrastive+Distill | AudioSet-2M | 38.9 | 90.4 | 98.1 |
| A-JEPA | JEPA (ours) | AudioSet-2M | 37.2 | 90.2 | 98.2 |
Linear probing results. A-JEPA achieves competitive mAP on AudioSet-20K and strong accuracy on ESC-50 and Speech Commands V2, outperforming Audio-MAE on all three benchmarks while approaching BEATs without requiring its iterative distillation procedure.
Fine-tuning Results
| Method | AudioSet-2M (mAP, FT) | ESC-50 (Acc., FT) |
|---|---|---|
| Audio-MAE (FT) | 47.3 | 94.1 |
| BEATs (iter3, FT) | 48.6 | 95.6 |
| A-JEPA (FT) | 47.6 | 94.5 |
With full fine-tuning, A-JEPA improves further and remains competitive with BEATs while exceeding Audio-MAE. Fine-tuning narrows the gap between methods, suggesting that A-JEPA's pre-trained representations provide a strong initialization.
Ablation Studies
Masking Strategy Ablation
| Masking Strategy | AudioSet-20K (mAP, LP) | ESC-50 (Acc., LP) |
|---|---|---|
| Random patch (75%) | 34.8 | 87.5 |
| Time-only strips | 35.9 | 88.8 |
| Frequency-only strips | 34.2 | 87.1 |
| Single block (75%) | 36.0 | 89.3 |
| Multi-block (4 blocks) | 37.2 | 90.2 |
The multi-block strategy significantly outperforms random masking (+2.4 mAP on AudioSet-20K), confirming that contiguous block masking forces more semantic reasoning. Multi-block also outperforms single-block masking (+1.2 mAP), as multiple blocks create a more varied prediction task that prevents the model from relying on any single context region. Frequency-only strips perform worst, suggesting that temporal reasoning is more important for audio understanding than spectral reasoning alone.
Masking Ratio Ablation
| Masking Ratio | AudioSet-20K (mAP, LP) |
|---|---|
| 40% | 35.1 |
| 50% | 36.0 |
| 60% | 36.8 |
| 70% (default) | 37.2 |
| 80% | 36.5 |
| 90% | 34.9 |
Performance peaks at ~70% masking ratio. Lower ratios leave too much context, allowing the predictor to rely on local interpolation. Higher ratios leave too little context, making prediction intractable and degrading feature quality. The optimal range (60–75%) is consistent with findings in I-JEPA for images.
Predictor Depth Ablation
| Predictor Depth | Predictor Dim | AudioSet-20K (mAP, LP) |
|---|---|---|
| 2 layers | 384 | 35.4 |
| 4 layers | 384 | 36.5 |
| 6 layers | 384 | 37.2 |
| 8 layers | 384 | 36.9 |
| 6 layers | 768 | 36.1 |
A predictor that is too shallow (2 layers) cannot capture the cross-patch reasoning needed for accurate prediction. Increasing to 6 layers yields the best results, while 8 layers provides no further gain. Importantly, widening the predictor to 768 dimensions (matching the encoder) hurts performance, confirming that the bottleneck is essential—a wider predictor can shortcut the encoder, reducing the quality of encoder features.
Loss Function Ablation
| Loss Function | AudioSet-20K (mAP, LP) |
|---|---|
| L2 (MSE) | 36.4 |
| Smooth L1 (Huber) | 37.2 |
| L1 | 36.8 |
| Cosine similarity | 36.0 |
Smooth L1 slightly outperforms both pure L2 and pure L1, suggesting that its adaptive behavior—quadratic for small errors, linear for large errors—provides a useful inductive bias for representation prediction in the audio domain where occasional outlier predictions are expected due to the high masking ratio.
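The Smooth L1 behavior described above is easy to verify directly. The sketch below uses the standard Huber formulation with `beta = 1.0`; the specific beta value is an assumption, not stated in the source.

```python
def smooth_l1(error, beta=1.0):
    """Smooth L1 (Huber) loss on a single residual: quadratic for |error| < beta,
    linear beyond it, so outlier residuals are penalized less harshly than under L2."""
    a = abs(error)
    if a < beta:
        return 0.5 * a * a / beta
    return a - 0.5 * beta

# Small error: quadratic regime, behaves like 0.5 * e^2 (scaled L2).
print(smooth_l1(0.2))  # 0.02
# Large error: linear regime, behaves like |e| - 0.5 (same slope as L1).
print(smooth_l1(3.0))  # 2.5
```

Under a 70% masking ratio some target representations are inherently hard to predict; the linear regime caps their gradient magnitude, which is the stabilizing effect the ablation suggests.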
10. Connection to the JEPA Family
Lineage
A-JEPA sits within a clear lineage of Joint-Embedding Predictive Architecture methods:
- JEPA (conceptual framework, LeCun 2022) — Proposed the general principle of learning by predicting representations of masked inputs rather than reconstructing raw inputs. Positioned as an alternative to both generative and contrastive self-supervised learning.
- I-JEPA (Assran et al., 2023) — First concrete implementation for natural images. Introduced multi-block masking, the narrow predictor bottleneck, and the specific EMA target encoder design that A-JEPA inherits. Demonstrated that JEPA principles produce strong image features without augmentation dependence.
- A-JEPA (Fei et al., 2023) — Transfers the I-JEPA framework to the audio domain via spectrogram representation. Adapts masking strategy for time-frequency structure and validates that JEPA principles generalize beyond natural images to acoustic signals.
A-JEPA is a sibling of other domain-specific JEPA variants developed during the same period, including V-JEPA (video) and MC-JEPA (joint motion and content features from video), along with various multimodal JEPA proposals. Together, these variants validate the generality of the JEPA framework across sensory modalities.
Key Novelty: JEPA for Time-Frequency Representations
A-JEPA's primary contribution is demonstrating that the JEPA prediction-in-representation-space paradigm transfers effectively to the audio domain. This is non-obvious because spectrograms have fundamentally different statistical properties from natural images: (1) the two axes carry different semantics (time vs. frequency), (2) information is often sparse and localized (e.g., a brief sound event in a long clip), and (3) there is strong temporal correlation but complex harmonic structure across frequencies. A-JEPA's success suggests that the JEPA framework's advantages—noise filtering, semantic focus, augmentation independence—are domain-general properties that arise from the prediction-in-representation-space formulation itself, rather than from specific properties of natural images.
The practical significance is that A-JEPA achieves performance competitive with purpose-built audio SSL methods (Audio-MAE, BEATs) using a relatively straightforward adaptation of a vision method, without requiring audio-specific augmentations, iterative training procedures, or external teacher models. This suggests that future improvements to the JEPA framework (e.g., better masking strategies, improved collapse prevention, or multi-scale prediction) could benefit audio understanding as a direct downstream application.
Comparison with Audio SSL Alternatives
| Property | A-JEPA | Audio-MAE | BEATs |
|---|---|---|---|
| Prediction target | Latent representations | Raw spectrogram pixels | Discrete audio tokens |
| Masking type | Multi-block contiguous | Random patch | N/A (contrastive+distill) |
| Augmentation needed | No | No | Yes |
| Iterative training | No | No | Yes (3 iterations) |
| External teacher | No (self-EMA) | No | Yes (iterative tokenizer) |
| Noise modeled | No (filtered by design) | Yes (reconstructs all) | Partially (token level) |
| Training simplicity | High | High | Low (multi-stage) |
Influence and Subsequent Work
A-JEPA's demonstration that JEPA transfers to audio has implications for the broader JEPA research program. It provides evidence that the core JEPA design decisions—multi-block masking, narrow predictor, EMA targets, representation-space loss—are not image-specific heuristics but general principles for self-supervised learning from structured 2D data. This finding encourages the development of JEPA variants for other modalities that can be represented as 2D grids: medical imaging (CT/MRI slices), radar spectrograms, seismic data, and more.
11. Summary
Main Contribution: A domain-adapted JEPA framework for audio that converts waveforms to mel-spectrograms, applies multi-block time-frequency masking tailored to the asymmetric axes of the spectrogram, and trains a context encoder + narrow predictor to predict latent representations of masked regions using EMA-stabilized target representations. The method validates that JEPA's core advantages—noise filtering, semantic focus, architectural simplicity—are domain-general properties, not artifacts of natural image statistics.
Limitations: A-JEPA has been evaluated primarily on classification benchmarks (AudioSet, ESC-50, Speech Commands). Its effectiveness for other audio tasks—speech recognition, sound generation, audio captioning, source separation—remains to be established. The method also lacks a public code repository, limiting reproducibility and community adoption. Finally, the collapse-prevention story (EMA + predictor bottleneck + multi-block masking) remains empirically motivated rather than theoretically grounded, and the precise contribution of each component to stability in the audio domain warrants further investigation.
12. References
- Fei, Z., Fan, M., & Huang, J. (2023). A-JEPA: Joint-Embedding Predictive Architecture Can Listen. arXiv:2311.15830.
- Assran, M., Duval, Q., Misra, I., Bojanowski, P., Vincent, P., Rabbat, M., LeCun, Y., & Ballas, N. (2023). Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture. CVPR 2023. arXiv:2301.08243.
- LeCun, Y. (2022). A Path Towards Autonomous Machine Intelligence. OpenReview preprint.
- Huang, P.-Y., Xu, H., Li, J., Baevski, A., Auli, M., Galuba, W., Metze, F., & Feichtenhofer, C. (2022). Masked Autoencoders that Listen. NeurIPS 2022. (Audio-MAE)
- Chen, S., Wu, Y., Wang, C., Liu, S., Tompkins, D., Chen, Z., & Wei, F. (2023). BEATs: Audio Pre-Training with Acoustic Tokenizers. ICML 2023.
- Gong, Y., Lai, C.-I., Chung, Y.-A., & Glass, J. (2022). SSAST: Self-Supervised Audio Spectrogram Transformer. AAAI 2022.
- Chong, D., Wang, H., Zhou, P., & Zeng, Q. (2023). Masked Spectrogram Prediction for Self-Supervised Audio Pre-Training. ICASSP 2023. (MaskSpec)
- Grill, J.-B., Strub, F., Altché, F., Tallec, C., Richemond, P. H., Buchatskaya, E., ... & Valko, M. (2020). Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning. NeurIPS 2020. (BYOL)
- He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). Masked Autoencoders Are Scalable Vision Learners. CVPR 2022.
- Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2021). An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021. (ViT)
- Bardes, A., Garrido, Q., Ponce, J., Chen, X., Rabbat, M., LeCun, Y., Assran, M., & Ballas, N. (2024). Revisiting Feature Prediction for Learning Visual Representations from Video. arXiv:2404.08471. (V-JEPA)
- Hsu, W.-N., Bolte, B., Tsai, Y.-H. H., Lakhotia, K., Salakhutdinov, R., & Mohamed, A. (2021). HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units. IEEE/ACM TASLP 2021.