Authors: Abdelfattah, Alahi
Date: 2024-10
Category: Skeleton / Action
Derives from: I-JEPA

S-JEPA: Skeletal Joint Embedding Predictive Architecture

Variant: S-JEPA (Skeletal JEPA)  |  Domain: 3D Skeleton Action Recognition  |  Date: 2024-10  |  Authors: Abdelfattah, Alahi  |  Derives from: I-JEPA
Repository: github.com/Moo-osama/S-JEPA

1. Introduction

Self-supervised learning for skeleton-based action recognition has followed two dominant paradigms: contrastive learning and generative masked modeling. Contrastive methods such as 3s-CrosSCLR, AimCLR, and HiCLR construct positive and negative skeleton pairs through augmentation pipelines and learn representations by pulling similar pairs together while pushing dissimilar ones apart. These approaches require careful augmentation design, large batch sizes or memory banks for sufficient negatives, and are sensitive to the choice of data transformations—particularly problematic for skeleton data, where aggressive augmentations can destroy the biomechanical plausibility of a pose sequence. Generative masked approaches such as SkeletonMAE and MAMP mask portions of the skeleton input and reconstruct raw 3D joint coordinates. While these avoid the negative-pair problem, they force the encoder to devote representational capacity to low-level spatial details—exact joint positions, bone lengths, and coordinate noise—that are largely irrelevant for downstream semantic tasks like action classification.

S-JEPA (Skeletal Joint Embedding Predictive Architecture) addresses both limitations by transposing the JEPA framework, originally developed for images in I-JEPA, to the skeleton domain. The core insight is simple but powerful: instead of predicting raw joint coordinates (input space), predict the latent representations of masked joints as produced by an exponential moving average (EMA) target encoder (representation space). This latent prediction objective discards unpredictable low-level detail and focuses the encoder on learning abstract, semantically meaningful features of human motion.

S-JEPA introduces three key innovations beyond the direct transposition of I-JEPA:

  1. Motion-aware spatial masking. Rather than masking random blocks (as in I-JEPA for images), S-JEPA computes per-joint motion magnitudes across the temporal sequence and biases masking toward high-motion joints. With a masking ratio of $r = 0.9$, the model retains only a handful of low-motion joints (e.g., torso, hips) and must predict the latent representations of the most informative, action-discriminative joints (e.g., hands, feet). This creates a harder and more semantically informative prediction task than uniform random masking.
  2. Cross-entropy loss with centering and sharpening. While I-JEPA uses an $L_2$ loss between predicted and target representations, S-JEPA treats encoder outputs as logits over a learned feature space, converts them to probability distributions via temperature-scaled softmax, and minimizes cross-entropy. A centering mechanism (running mean subtraction on target outputs) and asymmetric temperature sharpening prevent representational collapse without requiring negative pairs or explicit variance regularization.
  3. Geometric view transformations. S-JEPA applies random 3D geometric transformations (rotation, scaling, translation, reflection) to generate diverse views of each skeleton sequence. The view encoder receives one augmented view while the target encoder receives another, encouraging the learned representations to be invariant to viewpoint and scale changes—critical for skeleton-based recognition where camera placement varies across datasets.

Compared to I-JEPA, S-JEPA operates on structured spatiotemporal graph data (skeleton sequences) rather than 2D image patches, replaces block masking with a motion-informed joint selection strategy, and substitutes the $L_2$ reconstruction loss with a distributional cross-entropy objective. These changes are not mere cosmetic adaptations—they reflect fundamental differences between the spatial locality of image patches and the semantic heterogeneity of skeleton joints. Evaluated on the standard NTU RGB+D 60, NTU RGB+D 120, and PKU-MMD benchmarks, S-JEPA achieves competitive or superior performance relative to both contrastive and generative masked methods across linear probing, fine-tuning, and semi-supervised evaluation protocols.

2. Method

Understanding S-JEPA requires thinking about three ideas in sequence: what the model sees, what it must predict, and what it predicts about.

Intuition: The Choreographer's Blind Spot. Imagine a choreographer watching a dancer through a window, but the glass is frosted except for a narrow slit showing only the dancer's torso and one shoulder. The choreographer must infer what the dancer's arms and legs are doing based solely on the torso's tilt, rhythm, and momentum. S-JEPA works the same way: the model sees a few low-motion joints (the "slit") and must predict abstract representations of the high-motion limbs. The prediction is not about exact positions—the choreographer doesn't need to know the hand is at coordinate (0.3, 1.2, 0.7)—but about the type and quality of movement: "the right arm is sweeping upward in an arc consistent with a throwing motion."

Step 1: View Creation. Given a skeleton sequence of $T$ frames and $N$ joints, S-JEPA applies two independent random geometric transformations (rotation, scaling, translation, reflection) to produce two views of the same action. View 1 is processed by the view encoder (also called the online or student encoder); View 2 is processed by the target encoder. Using different views encourages the learned representations to capture the action's semantics rather than view-specific spatial details.

Step 2: Motion-Aware Masking. S-JEPA computes the average displacement of each joint across the temporal sequence. Joints with larger total motion receive higher masking probabilities. A set of $\lfloor r \cdot N \rfloor$ joints is sampled (without replacement) according to these motion-weighted probabilities, where $r = 0.9$. The masking is spatial: once a joint is selected for masking, it is masked across all $T$ frames. This leaves only a few stable, low-motion joints visible to the view encoder.

Step 3: Encoding and Prediction. The view encoder processes only the visible joint tokens (typically 2–3 joints across all frames). The predictor network takes these encoded representations plus learnable mask tokens and predicts latent representations for the masked joints. Simultaneously, the target encoder—a momentum-updated copy of the view encoder that receives no gradients—processes the complete second view (all $N$ joints, all $T$ frames) to produce target representations.

Intuition: Why Predict in Latent Space? Consider two skeleton sequences of a "waving" action. In one, the person waves fast with large amplitude; in the other, a different person waves slowly with small amplitude. The raw 3D coordinates are very different, but the semantic content is the same. An $L_2$ loss on raw coordinates would penalize the model for not matching exact positions, encouraging memorization of individual skeletons rather than understanding of actions. By predicting in the target encoder's latent space, S-JEPA only needs to match the abstract semantic representation, which should be similar for both waving sequences. The target encoder's EMA update ensures these target representations evolve smoothly and consistently.

Step 4: Loss and Stability. The predicted and target representations are converted to probability distributions via temperature-scaled softmax. The cross-entropy between the target distribution (sharpened with a low temperature $\tau_t$) and the predicted distribution (softer, with higher temperature $\tau_s$) is minimized. A centering vector—the running mean of target outputs—is subtracted before the target softmax to prevent collapse to a uniform or degenerate distribution. Gradients flow only through the view encoder and predictor; the target encoder is updated exclusively via EMA.

3. Model Overview

At-a-Glance

  • Input: 3D skeleton sequences, $T$ frames $\times$ $N$ joints $\times$ 3 channels (x, y, z)
  • Masking: motion-aware spatial masking; $r = 0.9$ (90% of joints masked across all frames), biased toward high-motion joints
  • View encoder: skeleton Transformer; processes visible (unmasked) joint tokens; trainable via backpropagation
  • Target encoder: same architecture as the view encoder; updated via EMA; no gradient; processes the full skeleton view
  • Predictor: lightweight Transformer; takes view-encoder output plus mask tokens; predicts latent representations of masked joints
  • Loss: cross-entropy on sharpened/centered probability distributions (not $L_2$)
  • Key innovation: motion-aware masking plus a distributional loss with centering/sharpening for skeleton JEPA
  • Benchmarks: NTU RGB+D 60, NTU RGB+D 120, PKU-MMD; competitive with contrastive and generative SSL methods

Training Architecture Diagram

Figure 1. S-JEPA training architecture. The skeleton input is augmented into two geometric views. View 1 is motion-aware masked (r=0.9) and processed by the trainable view encoder and predictor. View 2 is processed in full by the EMA target encoder. Cross-entropy loss is computed between centered/sharpened target distributions and predicted distributions at masked positions. Gradients flow only through the online path (green solid lines).

4. Main Components of S-JEPA

4.1 View Encoder ($f_\theta$)

WHAT: The view encoder is a Transformer that maps visible skeleton joint tokens to $D$-dimensional latent representations. It is the primary representation learning module and the component used at inference time.

HOW: Each visible joint at each frame is embedded as a token. Given a skeleton sequence $\mathbf{X} \in \mathbb{R}^{T \times N \times 3}$ and a set of visible joint indices $\mathcal{V} \subset \{1, \ldots, N\}$ (with $|\mathcal{V}| = N_v = N - \lfloor r \cdot N \rfloor$, i.e., $N_v = 3$ for $N = 25$, $r = 0.9$), the input tokens are:

$$e_j^t = \text{Linear}(\mathbf{p}_j^t) + \mathbf{E}_{\text{joint}}[j] + \mathbf{E}_{\text{time}}[t], \quad j \in \mathcal{V}, \; t \in \{1, \ldots, T\}$$

where $\mathbf{p}_j^t \in \mathbb{R}^3$ is the 3D position of joint $j$ at frame $t$, $\text{Linear}: \mathbb{R}^3 \to \mathbb{R}^D$ is a learnable linear projection, $\mathbf{E}_{\text{joint}} \in \mathbb{R}^{N \times D}$ is a learnable joint-identity embedding, and $\mathbf{E}_{\text{time}} \in \mathbb{R}^{T \times D}$ is a learnable temporal position embedding. The encoder processes the sequence of $N_v \cdot T$ tokens through $L$ Transformer layers with multi-head self-attention and MLP blocks:

$$\mathbf{H}^{(\ell)} = \text{TransformerBlock}^{(\ell)}(\mathbf{H}^{(\ell-1)}), \quad \ell = 1, \ldots, L$$

where $\mathbf{H}^{(0)} \in \mathbb{R}^{(N_v \cdot T) \times D}$ is the set of input embeddings. The output $\mathbf{H}^{(L)}$ provides the latent representations of visible joints.

WHY: Processing only visible tokens (rather than all tokens with masking indicators) provides a computational advantage proportional to the masking ratio—at $r = 0.9$, the encoder processes only 10% of tokens. This follows the efficient masking strategy from MAE and I-JEPA. The joint-identity and temporal position embeddings allow the Transformer to distinguish between different body parts and temporal positions without relying on skeleton graph topology explicitly, providing a more flexible and learnable spatial encoding than fixed graph adjacency.
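The tokenization step above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the repository's code; the names (`embed_visible`, `E_joint`, `E_time`) and the choice of visible joint indices are assumptions for demonstration.

```python
import torch
import torch.nn as nn

T, N, D = 64, 25, 256          # frames, joints, embedding dimension
r = 0.9
N_v = N - int(r * N)           # visible joints (3 for the NTU skeleton)

linear = nn.Linear(3, D)                   # projects (x, y, z) -> R^D
E_joint = nn.Parameter(torch.zeros(N, D))  # learnable joint-identity embedding
E_time = nn.Parameter(torch.zeros(T, D))   # learnable temporal position embedding

def embed_visible(X, visible):
    """X: [T, N, 3] skeleton; visible: [N_v] joint indices.
    Returns the [N_v * T, D] sequence of input tokens e_j^t."""
    P = X[:, visible, :]                    # [T, N_v, 3] visible positions
    e = linear(P)                           # Linear(p_j^t)
    e = e + E_joint[visible][None, :, :]    # + E_joint[j]
    e = e + E_time[:, None, :]              # + E_time[t]
    return e.reshape(-1, D)                 # flatten to N_v * T tokens

X = torch.randn(T, N, 3)
visible = torch.tensor([0, 1, 20])   # illustrative low-motion joints (hip, spine, head)
tokens = embed_visible(X, visible)
print(tokens.shape)                  # torch.Size([192, 256]) = (N_v * T, D)
```

Only these $N_v \cdot T$ tokens enter the Transformer blocks, which is where the 10x compute saving at $r = 0.9$ comes from.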

4.2 Target Encoder ($g_\xi$)

WHAT: The target encoder has identical architecture to the view encoder but differs in three critical ways: (1) it processes the full skeleton view (all $N$ joints, all $T$ frames) from the second geometric augmentation; (2) it receives no gradients (stop-gradient); and (3) its parameters $\xi$ are updated via exponential moving average of the view encoder parameters $\theta$.

HOW: The EMA update at each training step is:

$$\xi \leftarrow \tau_{\text{ema}} \cdot \xi + (1 - \tau_{\text{ema}}) \cdot \theta$$

where $\tau_{\text{ema}} \in [0, 1)$ follows a cosine schedule from an initial value $\tau_0$ (e.g., 0.996) to a final value approaching 1.0:

$$\tau_{\text{ema}}(t) = 1 - (1 - \tau_0) \cdot \frac{1 + \cos(\pi t / T_{\max})}{2}$$

The target encoder produces representations $\mathbf{Z}^{\text{tgt}} \in \mathbb{R}^{(N \cdot T) \times D}$ for all joint-frame tokens, which a projection head maps to $\mathbb{R}^K$ for the loss. Only the representations at masked positions $\mathcal{M}$ are used as prediction targets.

WHY: The EMA target encoder serves as a slowly evolving representation anchor. Without it, the system could trivially collapse—both encoder and predictor could learn to output a constant vector regardless of input, achieving zero loss. The EMA mechanism, inherited from BYOL and refined in I-JEPA, creates an asymmetry: the target representations change slowly and smoothly, providing stable supervision for the online path. The cosine schedule for $\tau_{\text{ema}}$ starts with faster target updates (enabling the target to incorporate early learning signals) and gradually slows updates to near-identity (providing increasingly stable targets as training matures). The stop-gradient on the target path is essential: without it, the gradient signal would flow through both paths and the asymmetry that prevents collapse would disappear.
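The EMA update and its cosine schedule can be sketched as follows; this is a hedged, self-contained illustration in which small linear layers stand in for the two encoders (`tau_ema`, `ema_update` are names chosen here, not the repository's).

```python
import math
import torch
import torch.nn as nn

def tau_ema(t, t_max, tau0=0.996):
    """Cosine schedule: equals tau0 at t = 0 and approaches 1.0 at t = t_max."""
    return 1 - (1 - tau0) * (1 + math.cos(math.pi * t / t_max)) / 2

@torch.no_grad()
def ema_update(online, target, tau):
    """xi <- tau * xi + (1 - tau) * theta, applied parameter-wise."""
    for p_o, p_t in zip(online.parameters(), target.parameters()):
        p_t.mul_(tau).add_(p_o, alpha=1 - tau)

online = nn.Linear(4, 4)                      # stand-in for f_theta
target = nn.Linear(4, 4)                      # stand-in for g_xi
target.load_state_dict(online.state_dict())   # target initialized as a copy

ema_update(online, target, tau_ema(t=0, t_max=100))    # faster updates early
ema_update(online, target, tau_ema(t=100, t_max=100))  # near-identity late
```

At $t = 0$ the schedule returns exactly $\tau_0$, and at $t = T_{\max}$ it returns 1.0, matching the closed form above.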

4.3 Predictor ($p_\phi$)

WHAT: The predictor is a lightweight Transformer that takes the view encoder's output for visible tokens, along with learnable mask tokens at masked positions, and produces predicted latent representations for the masked joints.

HOW: The predictor assembles a full sequence of $N \cdot T$ tokens by combining the view encoder output (at visible positions) with learnable mask tokens $\mathbf{m} \in \mathbb{R}^D$ (at masked positions), augmented with the corresponding joint-identity and temporal position embeddings:

$$\tilde{e}_j^t = \begin{cases} f_\theta(\mathbf{X}_{\text{vis}})_j^t + \mathbf{E}_{\text{joint}}[j] + \mathbf{E}_{\text{time}}[t] & \text{if } j \in \mathcal{V} \\ \mathbf{m} + \mathbf{E}_{\text{joint}}[j] + \mathbf{E}_{\text{time}}[t] & \text{if } j \in \mathcal{M} \end{cases}$$

The predictor processes these $N \cdot T$ tokens through $L_p$ Transformer layers (with $L_p < L$, typically $L_p \approx L/2$) and outputs predictions $\hat{\mathbf{Z}} \in \mathbb{R}^{(N \cdot T) \times D}$, from which only the masked positions are extracted for the loss.

WHY: The predictor acts as a capacity bottleneck. It is deliberately shallower and potentially narrower than the encoder, preventing a "shortcut" where the predictor simply copies or memorizes the target representations. This forces the view encoder to produce rich, informative representations that a simple predictor can map to target representations. The use of positional embeddings in the predictor allows it to know where each masked joint should be (which body part, which time step), so it can use the encoded visible context to predict what the representation should be at that position. The predictor is discarded at inference time—only the view encoder is retained.
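The token-assembly step described above can be sketched as follows. This is an illustrative fragment, assuming PyTorch; the embeddings are zero-initialized placeholders and the visible-joint indices are arbitrary, not values from the paper.

```python
import torch

T, N, D = 64, 25, 256
visible = torch.tensor([0, 1, 20])                           # N_v = 3 visible joints
masked = torch.tensor(sorted(set(range(N)) - set(visible.tolist())))

H = torch.randn(len(visible) * T, D)        # view-encoder output, [N_v * T, D]
mask_token = torch.zeros(D)                 # the learnable token m (placeholder here)
E_joint, E_time = torch.zeros(N, D), torch.zeros(T, D)

tokens = torch.zeros(T, N, D)
tokens[:, visible, :] = H.reshape(T, len(visible), D)        # encoded visible tokens
tokens[:, masked, :] = mask_token                            # shared mask token
tokens = tokens + E_joint[None, :, :] + E_time[:, None, :]   # add positional info
tokens = tokens.reshape(N * T, D)                            # full predictor input
```

The positional embeddings are what distinguish one mask token from another: every masked slot carries the identity of the joint and frame whose representation must be predicted.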

4.4 Motion-Aware Spatial Masking

WHAT: S-JEPA's masking strategy selects which joints to mask based on their temporal motion magnitude. Joints with higher motion are masked with higher probability. The masking is spatial: selected joints are masked across all $T$ frames.

HOW: For each joint $j \in \{1, \ldots, N\}$, compute the average frame-to-frame displacement:

$$m_j = \frac{1}{T - 1} \sum_{t=1}^{T-1} \|\mathbf{p}_j^{t+1} - \mathbf{p}_j^t\|_2$$

The masking probability for joint $j$ is then proportional to a powered version of its motion magnitude:

$$P(\text{mask}_j) = \frac{m_j^\alpha}{\sum_{k=1}^{N} m_k^\alpha}$$

where $\alpha \geq 0$ controls the sharpness of the bias. When $\alpha = 0$, the distribution is uniform (random masking); as $\alpha \to \infty$, the distribution becomes deterministic (always mask the highest-motion joints). A set of $N_m = \lfloor r \cdot N \rfloor$ joints is sampled without replacement from this categorical distribution. For NTU datasets with $N = 25$ and $r = 0.9$, this yields $N_m = 22$ masked joints and $N_v = 3$ visible joints per skeleton.

Figure 2. Motion-aware spatial masking strategy. Left: skeleton joints colored by temporal motion magnitude. Middle: masking probability distribution biased toward high-motion joints. Right: resulting masked skeleton with only ~3 low-motion joints (spine, head, hip) remaining visible. The model must predict latent representations of all 22 masked joints from only 3 visible joints across all T frames.

WHY: Random uniform masking treats all joints equally, but skeleton joints carry vastly different amounts of action-discriminative information. During a "throwing" action, the hand and elbow joints undergo large displacements while the spine and hips remain relatively stationary. Masking the informative (high-motion) joints and retaining the stable (low-motion) joints forces the model to learn how body-part dynamics relate to global action semantics—a harder and more informative pretext task. The ablation studies in the S-JEPA paper confirm that motion-aware masking consistently outperforms random masking by a significant margin across all evaluation protocols and benchmarks. The exponent $\alpha$ provides a tunable knob: lower values soften the bias toward uniform, while higher values concentrate masking on the fastest-moving joints.
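The masking procedure above can be written as a short runnable sketch (assuming PyTorch; the $\epsilon$ guard for static joints follows Algorithm 2, and the function name is illustrative):

```python
import torch

def motion_aware_mask(X, r=0.9, alpha=1.0, eps=1e-6):
    """X: [T, N, 3] skeleton sequence.
    Returns (masked, visible) joint index tensors."""
    T, N, _ = X.shape
    disp = (X[1:] - X[:-1]).norm(dim=-1)    # [T-1, N] frame-to-frame displacements
    m = disp.mean(dim=0) + eps              # m_j; eps keeps static joints sampleable
    w = m ** alpha                          # motion exponent controls bias sharpness
    probs = w / w.sum()                     # P(mask_j)
    n_mask = int(r * N)                     # floor(r * N)
    masked = torch.multinomial(probs, n_mask, replacement=False)
    visible = torch.tensor(
        [j for j in range(N) if j not in set(masked.tolist())]
    )
    return masked, visible

X = torch.randn(64, 25, 3)
masked, visible = motion_aware_mask(X)      # 22 masked, 3 visible for NTU skeletons
```

Setting `alpha=0.0` recovers uniform random masking, which makes the ablation comparison in the paper easy to reproduce with the same sampler.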

4.5 Loss Function

WHAT: S-JEPA uses a cross-entropy loss on probability distributions derived from the predictor and target encoder outputs, combined with centering and sharpening mechanisms for training stability.

HOW: Given the predictor output $\hat{\mathbf{z}}_i \in \mathbb{R}^K$ and the target encoder output $\mathbf{z}_i^{\text{tgt}} \in \mathbb{R}^K$ for a masked position $i \in \mathcal{M}$, the predicted and target probability distributions are:

Target distribution (sharpened and centered):

$$q_i^{(k)} = \frac{\exp\bigl((\mathbf{z}_{i}^{\text{tgt},(k)} - c^{(k)}) / \tau_t\bigr)}{\sum_{k'=1}^{K} \exp\bigl((\mathbf{z}_{i}^{\text{tgt},(k')} - c^{(k')}) / \tau_t\bigr)}, \quad k = 1, \ldots, K$$

Predicted distribution:

$$p_i^{(k)} = \frac{\exp(\hat{\mathbf{z}}_i^{(k)} / \tau_s)}{\sum_{k'=1}^{K} \exp(\hat{\mathbf{z}}_i^{(k')} / \tau_s)}, \quad k = 1, \ldots, K$$

Cross-entropy loss:

$$\mathcal{L} = -\frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \sum_{k=1}^{K} q_i^{(k)} \log p_i^{(k)}$$

where:

  • $\mathcal{M}$ — set of masked joint-frame positions, $|\mathcal{M}| = N_m \cdot T$
  • $K$ — output dimensionality of the representation (both encoders project to $\mathbb{R}^K$)
  • $\mathbf{z}_i^{\text{tgt},(k)}$ — $k$-th component of the target encoder output at masked position $i$
  • $\hat{\mathbf{z}}_i^{(k)}$ — $k$-th component of the predictor output at masked position $i$
  • $c^{(k)}$ — $k$-th component of the centering vector $\mathbf{c} \in \mathbb{R}^K$
  • $\tau_t$ — target temperature (low, e.g., 0.04); controls sharpness of target distribution
  • $\tau_s$ — student/predictor temperature (higher, e.g., 0.1); softer predicted distribution

Centering update: The centering vector $\mathbf{c}$ is an exponential moving average of the mean target encoder output across the batch:

$$\mathbf{c} \leftarrow \beta \cdot \mathbf{c} + (1 - \beta) \cdot \frac{1}{B \cdot N \cdot T} \sum_{b=1}^{B} \sum_{j=1}^{N} \sum_{t=1}^{T} \mathbf{z}_{b,j,t}^{\text{tgt}}$$

where $B$ is the batch size and $\beta$ is the centering momentum (e.g., $\beta = 0.9$).
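A minimal functional sketch of this loss and center update, assuming PyTorch (the function and variable names are illustrative, not the repository's):

```python
import torch
import torch.nn.functional as F

def sjepa_loss(pred, target, center, tau_s=0.1, tau_t=0.04, beta=0.9):
    """pred, target: [M, K] outputs at masked positions; center: [K].
    Returns (loss, updated center). The target path carries no gradient."""
    q = F.softmax((target.detach() - center) / tau_t, dim=-1)  # sharpened, centered
    log_p = F.log_softmax(pred / tau_s, dim=-1)                # softer prediction
    loss = -(q * log_p).sum(dim=-1).mean()                     # cross-entropy H(q, p)
    new_center = beta * center + (1 - beta) * target.detach().mean(dim=0)
    return loss, new_center

M, K = 22 * 64, 256                    # masked positions per skeleton, output dim
pred = torch.randn(M, K, requires_grad=True)
target = torch.randn(M, K)
loss, center = sjepa_loss(pred, target, center=torch.zeros(K))
loss.backward()                        # gradients reach pred only
```

Because the target is detached and centered before the softmax, gradient descent can only lower the loss by moving the predictor output toward the target distribution, never by dragging the target toward a degenerate point.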

WHY: The cross-entropy loss with centering and sharpening directly addresses collapse prevention:

  • Centering subtracts the running mean from target outputs, preventing the target encoder from collapsing to a single point in representation space. If all target representations converge to the same vector, centering zeroes them out, making the loss uninformative and pushing the model away from that collapsed state.
  • Sharpening via asymmetric temperatures ($\tau_t < \tau_s$) ensures the target distribution is peaked (high-confidence) while the predicted distribution is softer. This asymmetry encourages the predictor to match the target's mode without both distributions collapsing to uniform. The low target temperature amplifies differences between dimensions, preserving discriminative structure.
  • Together, centering and sharpening provide complementary collapse prevention: centering prevents point collapse (all representations identical), while sharpening prevents uniform collapse (all distribution components equal). This combination, inspired by DINO's approach to self-distillation, is an alternative to the $L_2$ loss used in I-JEPA, which relies more on the predictor bottleneck and EMA alone for collapse avoidance.

It is worth noting that while centering and sharpening empirically stabilize training, they do not constitute a formal guarantee against all collapse modes. The complete stability of S-JEPA training arises from the interaction of multiple mechanisms: EMA target encoding, stop-gradient, predictor bottleneck, centering, sharpening, and high masking ratio. The relative contribution of each component is an empirical question addressed partially by ablation studies.

4.6 Geometric View Transformations

WHAT: S-JEPA generates two augmented views of each skeleton sequence by applying independent random 3D geometric transformations. These transformations operate in the 3D coordinate space of the skeleton, preserving biomechanical plausibility.

HOW: Each transformation $\tau$ is a composition of:

  • Random rotation: rotation around the vertical (y) axis by a uniformly sampled angle $\theta \sim U(-\theta_{\max}, \theta_{\max})$, and optionally small rotations around x and z axes for tilt variation
  • Random scaling: uniform scaling factor $s \sim U(s_{\min}, s_{\max})$ applied to all coordinates
  • Random translation: displacement vector $\mathbf{d} \sim U(-d_{\max}, d_{\max})^3$ added to all joint positions
  • Random reflection: with probability 0.5, mirror the skeleton along the sagittal plane (swap left/right joints), simulating left-handed vs. right-handed execution of the same action

The two views $\mathbf{X}^{(1)} = \tau_1(\mathbf{X})$ and $\mathbf{X}^{(2)} = \tau_2(\mathbf{X})$ present the same action under different spatial configurations. View 1 is masked and processed by the view encoder; View 2 is processed in full by the target encoder.

WHY: Geometric view diversity serves two purposes. First, it encourages view-invariant representations: since the target encoder sees a different spatial arrangement than the view encoder, the predictor must learn to predict representations that abstract away from specific viewpoints, scales, and positions. Second, it acts as data augmentation, increasing the effective training set diversity. This is particularly important for skeleton data, where the underlying action repertoire is fixed but camera angles, body sizes, and spatial offsets vary across recordings and datasets.
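The four transformation types can be composed as in the following sketch. The parameter ranges (`theta_max`, `s_range`, `d_max`) are assumptions for illustration, not the paper's exact values, and a full implementation of the reflection would also swap the left/right joint indices rather than only mirroring coordinates.

```python
import math
import torch

def random_view(X, theta_max=math.pi / 6, s_range=(0.9, 1.1), d_max=0.1):
    """X: [T, N, 3]; returns one randomly transformed view of the same action."""
    theta = float((torch.rand(()) * 2 - 1) * theta_max)
    c, s = math.cos(theta), math.sin(theta)
    R = torch.tensor([[c, 0.0, s],            # rotation about the vertical y-axis
                      [0.0, 1.0, 0.0],
                      [-s, 0.0, c]])
    scale = torch.empty(()).uniform_(*s_range)          # uniform scaling
    d = (torch.rand(3) * 2 - 1) * d_max                 # random translation
    Xv = (X @ R.T) * scale + d
    if torch.rand(()) < 0.5:                            # sagittal reflection
        Xv = Xv * torch.tensor([-1.0, 1.0, 1.0])        # (joint swap omitted here)
    return Xv

X = torch.randn(64, 25, 3)
view1, view2 = random_view(X), random_view(X)   # two independent views
```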

5. Implementation Details

  • Input frames ($T$): 64 (subsampled from raw sequences)
  • Number of joints ($N$): 25 (NTU skeleton format)
  • Input channels: 3 (x, y, z 3D joint coordinates)
  • Encoder layers ($L$): 8 Transformer blocks
  • Encoder heads: 8 (multi-head self-attention)
  • Encoder dimension ($D$): 256 (embedding and hidden dimension)
  • Predictor layers ($L_p$): 4 (shallower than the encoder)
  • Predictor heads: 8 (same as the encoder)
  • Output dimension ($K$): 256 (projection head output; used for CE loss distributions)
  • Masking ratio ($r$): 0.9 (90% of joints masked spatially)
  • Motion exponent ($\alpha$): > 0 (controls motion-bias sharpness)
  • Optimizer: AdamW (with weight decay)
  • Base learning rate: 1.5 × 10⁻⁴ (scaled by batch size)
  • LR schedule: cosine decay with linear warmup
  • Warmup epochs: 10–20 (linear LR increase)
  • Total epochs: 200–400 (varies by benchmark)
  • Batch size: 128–256 (per GPU)
  • Target temperature ($\tau_t$): 0.04 (low: sharp target distributions)
  • Student temperature ($\tau_s$): 0.1 (higher: softer predictions)
  • Centering momentum ($\beta$): 0.9 (EMA of the target mean)
  • EMA schedule ($\tau_{\text{ema}}$): 0.996 → 1.0 (cosine schedule)
  • Weight decay: 0.05 (AdamW regularization)

Repository structure. The S-JEPA codebase (github.com/Moo-osama/S-JEPA) organizes the implementation around the following key modules:

# Key classes and modules (sketched after the S-JEPA repository; simplified
# and made self-contained here, so details may differ from the actual code)
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

# Model architecture
class SJEPA(nn.Module):
    """Main S-JEPA model with view encoder, target encoder, and predictor."""
    def __init__(self, encoder, predictor):
        super().__init__()
        self.view_encoder = encoder                    # Trainable view encoder f_θ
        self.target_encoder = copy.deepcopy(encoder)   # EMA target g_ξ
        self.target_encoder.requires_grad_(False)      # No gradients on target path
        self.predictor = predictor                     # Lightweight predictor p_φ

    @torch.no_grad()
    def ema_update(self, tau):
        """Exponential moving average update for the target encoder."""
        for p_online, p_target in zip(
            self.view_encoder.parameters(),
            self.target_encoder.parameters(),
        ):
            p_target.data = tau * p_target.data + (1 - tau) * p_online.data

# Masking
class MotionAwareMasking:
    """Computes motion magnitudes and samples masked joints."""
    def __call__(self, skeleton, mask_ratio=0.9, alpha=1.0):
        # skeleton: [T, N, 3]
        motion = self.compute_motion(skeleton)            # [N]
        weights = motion ** alpha
        probs = weights / weights.sum()
        num_masks = int(mask_ratio * skeleton.shape[1])
        masked_indices = torch.multinomial(probs, num_masks, replacement=False)
        return masked_indices

    def compute_motion(self, skeleton, eps=1e-6):
        """Average frame-to-frame displacement per joint (eps avoids zero prob.)."""
        return (skeleton[1:] - skeleton[:-1]).norm(dim=-1).mean(dim=0) + eps

# Loss
class SJEPALoss(nn.Module):
    """Cross-entropy loss with centering and sharpening."""
    def __init__(self, out_dim, tau_t=0.04, tau_s=0.1, center_momentum=0.9):
        super().__init__()
        self.tau_t, self.tau_s = tau_t, tau_s
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(out_dim))  # Running target mean

    def forward(self, pred, target):
        target_dist = F.softmax((target - self.center) / self.tau_t, dim=-1)
        pred_dist = F.log_softmax(pred / self.tau_s, dim=-1)
        loss = -torch.sum(target_dist * pred_dist, dim=-1).mean()
        self.update_center(target)
        return loss

    @torch.no_grad()
    def update_center(self, target):
        batch_mean = target.mean(dim=tuple(range(target.dim() - 1)))
        self.center.mul_(self.center_momentum).add_(
            batch_mean, alpha=1 - self.center_momentum
        )

6. Algorithm

Algorithm 1: S-JEPA Pre-training
Input: Skeleton dataset $\mathcal{D} = \{\mathbf{X}_i\}_{i=1}^{N_{\text{data}}}$, each $\mathbf{X}_i \in \mathbb{R}^{T \times N \times 3}$
Input: View encoder $f_\theta$, target encoder $g_\xi$ (initialized as copy of $f_\theta$), predictor $p_\phi$
Input: Masking ratio $r = 0.9$, motion exponent $\alpha$, temperatures $\tau_t, \tau_s$, centering momentum $\beta$
Input: EMA schedule $\tau_{\text{ema}}(\cdot)$, learning rate schedule $\eta(\cdot)$, total steps $T_{\max}$
Output: Trained view encoder $f_\theta$
 
1 Initialize centering vector $\mathbf{c} \leftarrow \mathbf{0} \in \mathbb{R}^K$
2 for $t = 1$ to $T_{\max}$ do
3 Sample mini-batch $\{\mathbf{X}_b\}_{b=1}^{B}$ from $\mathcal{D}$
4 for each $\mathbf{X}_b$ in batch do
5 Generate two augmented views: $\mathbf{X}_b^{(1)} = \tau_1(\mathbf{X}_b)$, $\mathbf{X}_b^{(2)} = \tau_2(\mathbf{X}_b)$
6 Compute motion-aware mask: $\mathcal{M}_b, \mathcal{V}_b \leftarrow \text{MotionAwareMask}(\mathbf{X}_b, r, \alpha)$   ▷ Alg. 2
7 Extract visible joints: $\mathbf{X}_{b,\text{vis}}^{(1)} = \mathbf{X}_b^{(1)}[:, \mathcal{V}_b, :]$   ▷ $T \times N_v \times 3$
8 Encode visible joints: $\mathbf{H}_b = f_\theta(\mathbf{X}_{b,\text{vis}}^{(1)})$   ▷ $(N_v \cdot T) \times D$
9 Predict masked representations: $\hat{\mathbf{Z}}_b = p_\phi(\mathbf{H}_b, \mathcal{M}_b)$   ▷ $(N_m \cdot T) \times K$
10 with no_grad():
11 Encode full view: $\mathbf{Z}_b^{\text{tgt}} = g_\xi(\mathbf{X}_b^{(2)})$   ▷ $(N \cdot T) \times K$
12 Select target at masked positions: $\mathbf{Z}_{b,\mathcal{M}}^{\text{tgt}} = \mathbf{Z}_b^{\text{tgt}}[\mathcal{M}_b]$   ▷ $(N_m \cdot T) \times K$
13 Compute target distributions: $q_{b,i}^{(k)} = \text{softmax}\bigl((\mathbf{Z}_{b,\mathcal{M},i}^{\text{tgt}} - \mathbf{c}) / \tau_t\bigr)$  $\forall i \in \mathcal{M}_b$
14 Compute predicted distributions: $p_{b,i}^{(k)} = \text{softmax}\bigl(\hat{\mathbf{Z}}_{b,i} / \tau_s\bigr)$  $\forall i \in \mathcal{M}_b$
15 Compute loss: $\mathcal{L} = -\frac{1}{B \cdot |\mathcal{M}|} \sum_{b} \sum_{i \in \mathcal{M}_b} \sum_{k} q_{b,i}^{(k)} \log p_{b,i}^{(k)}$
16 Update $\theta, \phi$: $(\theta, \phi) \leftarrow (\theta, \phi) - \eta(t) \cdot \nabla_{(\theta, \phi)} \mathcal{L}$   ▷ AdamW step
17 Update target encoder: $\xi \leftarrow \tau_{\text{ema}}(t) \cdot \xi + (1 - \tau_{\text{ema}}(t)) \cdot \theta$
18 Update center: $\mathbf{c} \leftarrow \beta \cdot \mathbf{c} + (1 - \beta) \cdot \frac{1}{B \cdot N \cdot T} \sum_b \sum_{j,t} \mathbf{z}_{b,j,t}^{\text{tgt}}$
19 return $f_\theta$
Algorithm 2: Motion-Aware Spatial Masking
Input: Skeleton sequence $\mathbf{X} \in \mathbb{R}^{T \times N \times 3}$, masking ratio $r$, motion exponent $\alpha$
Output: Masked joint indices $\mathcal{M}$, visible joint indices $\mathcal{V}$
 
1 $N_m \leftarrow \lfloor r \cdot N \rfloor$   ▷ Number of joints to mask
2 for $j = 1$ to $N$ do
3 $m_j \leftarrow \frac{1}{T-1} \sum_{t=1}^{T-1} \|\mathbf{p}_j^{t+1} - \mathbf{p}_j^{t}\|_2$   ▷ Average frame-to-frame displacement
4 $m_j \leftarrow m_j + \epsilon$  $\forall j$   ▷ Small $\epsilon$ to avoid zero probability for static joints
5 $w_j \leftarrow m_j^\alpha$  $\forall j$   ▷ Apply motion exponent
6 $P_j \leftarrow w_j \big/ \sum_{k=1}^{N} w_k$  $\forall j$   ▷ Normalize to probability distribution
7 $\mathcal{M} \leftarrow \text{MultinomialSample}(\{P_j\}_{j=1}^{N}, N_m, \text{replace}=\text{False})$   ▷ Sample without replacement
8 $\mathcal{V} \leftarrow \{1, \ldots, N\} \setminus \mathcal{M}$
9 return $\mathcal{M}, \mathcal{V}$   ▷ Masking is spatial: same joints masked for all $T$ frames

7. Training

Step-by-Step: One Training Iteration

Step 1 — Data Loading and Augmentation. A mini-batch of $B$ skeleton sequences $\{\mathbf{X}_b\}_{b=1}^{B}$, each $\mathbf{X}_b \in \mathbb{R}^{T \times N \times 3}$, is loaded and temporally subsampled to $T$ frames. Two independent geometric transformations $\tau_1, \tau_2$ (random rotation, scaling, translation, optional reflection) are applied to produce View 1 and View 2.

Step 2 — Motion-Aware Masking. For each skeleton in the batch, the per-joint motion magnitude $m_j$ is computed from View 1 (or from the original unaugmented sequence). The motion-weighted probability distribution is formed, and $N_m = \lfloor 0.9 \cdot N \rfloor$ joints are sampled for masking. The same joint mask applies across all $T$ frames, yielding $N_v \cdot T$ visible tokens and $N_m \cdot T$ masked tokens per skeleton.

Step 3 — View Encoding. The visible joint tokens from View 1 (after linear projection and positional embedding) are fed to the view encoder $f_\theta$. The encoder processes $N_v \cdot T$ tokens through $L$ Transformer layers, producing encoded representations $\mathbf{H} \in \mathbb{R}^{(N_v \cdot T) \times D}$.

Step 4 — Prediction. The predictor $p_\phi$ constructs a full-length token sequence by placing encoded visible representations at visible positions and learnable mask tokens at masked positions, both augmented with joint-identity and temporal positional embeddings. The predictor Transformer processes all $N \cdot T$ tokens through $L_p$ layers. Outputs at masked positions are extracted as predictions $\hat{\mathbf{Z}} \in \mathbb{R}^{(N_m \cdot T) \times K}$.

Step 5 — Target Computation (no gradient). View 2 (full, unmasked) is processed by the target encoder $g_\xi$, producing representations for all $N \cdot T$ tokens. Representations at the same masked positions $\mathcal{M}$ are extracted as targets $\mathbf{Z}^{\text{tgt}} \in \mathbb{R}^{(N_m \cdot T) \times K}$. No gradient flows through this path.

Step 6 — Distribution Formation. Target representations are centered (subtract running mean $\mathbf{c}$) and passed through softmax with temperature $\tau_t$ to produce sharpened target distributions $q$. Predicted representations are passed through softmax with temperature $\tau_s$ to produce predicted distributions $p$.

Step 7 — Loss and Gradient Update. The cross-entropy loss $\mathcal{L} = -\frac{1}{|\mathcal{M}|}\sum_i \sum_k q_i^{(k)} \log p_i^{(k)}$ is computed, averaged over all masked positions and the batch. Gradients $\nabla_{(\theta, \phi)} \mathcal{L}$ are computed and applied to the view encoder and predictor parameters via AdamW.

Step 8 — EMA and Center Update. The target encoder parameters are updated: $\xi \leftarrow \tau_{\text{ema}}(t) \cdot \xi + (1 - \tau_{\text{ema}}(t)) \cdot \theta$. The centering vector is updated: $\mathbf{c} \leftarrow \beta \cdot \mathbf{c} + (1 - \beta) \cdot \bar{\mathbf{z}}^{\text{tgt}}$. Both updates are performed without gradient computation.
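Steps 6–8 can be sketched on precomputed representations. The temperature and momentum values below are illustrative defaults, not the paper's settings; parameter-level gradient updates are framework-specific and omitted.

```python
import numpy as np

def softmax(z, temp):
    """Temperature-scaled softmax over the last axis."""
    z = z / temp
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def jepa_loss_and_center(z_pred, z_tgt, c, tau_s=0.1, tau_t=0.04, beta=0.9):
    """Distribution formation, cross-entropy loss, and center update.

    z_pred: predictor outputs at masked positions, (M, K).
    z_tgt:  target-encoder outputs at the same positions, (M, K).
    c:      running center vector, (K,).
    """
    q = softmax(z_tgt - c, tau_t)            # centered, sharpened targets
    p = softmax(z_pred, tau_s)               # predicted distributions
    loss = -(q * np.log(p + 1e-12)).sum(axis=-1).mean()
    c_new = beta * c + (1 - beta) * z_tgt.mean(axis=0)  # center EMA
    return loss, c_new

def ema_update(xi, theta, tau_ema):
    """EMA update for one target-encoder tensor: xi <- tau*xi + (1-tau)*theta."""
    return tau_ema * xi + (1 - tau_ema) * theta
```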

Training Architecture: Gradient Flow Diagram

[Figure 3 diagram. Online path (gradients): skeleton $B \times T \times N \times 3$ → augmentations $\tau_1, \tau_2$ → motion mask ($r = 0.9$) → view encoder $f_\theta$ (8 layers, 8 heads; visible tokens only) → predictor $p_\phi$ (4 layers, 8 heads; with mask tokens) → $\text{softmax}(\hat{z}/\tau_s)$. Target path (no gradients): target encoder $g_\xi$ (EMA of $f_\theta$) on full View 2 → select masked → center ($-\mathbf{c}$) → $\text{softmax}(z/\tau_t)$. Cross-entropy loss $H(q, p)$. Typical dimensions: $N_v = 3$, $N_m = 22$, $T = 64$, $D = 256$, $K = 256$.]
Figure 3. One S-JEPA training iteration with gradient flow annotations. Green solid lines indicate the trainable online path (view encoder + predictor). Dashed lines indicate the frozen target path (no gradients). The EMA update copies view encoder parameters to the target encoder after each optimizer step. Dimension annotations show typical tensor shapes throughout the pipeline.

8. Inference

At inference time, S-JEPA discards the predictor, the target encoder, and the masking mechanism entirely. Only the trained view encoder $f_\theta$ is retained. The inference pipeline is significantly simpler than training:

  1. Input processing: A skeleton sequence $\mathbf{X} \in \mathbb{R}^{T \times N \times 3}$ is loaded and temporally subsampled to $T$ frames. No geometric augmentation is applied (or a fixed canonical normalization is applied, such as centering at the hip joint).
  2. Full encoding: All $N$ joints across all $T$ frames are embedded (no masking) and processed through the view encoder, producing $\mathbf{H} \in \mathbb{R}^{(N \cdot T) \times D}$.
  3. Representation pooling: The token-level representations are aggregated into a single sequence-level representation $\mathbf{h} \in \mathbb{R}^D$, typically via global average pooling across all joint-frame tokens: $\mathbf{h} = \frac{1}{N \cdot T} \sum_{j=1}^{N} \sum_{t=1}^{T} \mathbf{H}_{j,t}$.
  4. Downstream head: The pooled representation is passed to a task-specific head for classification or other downstream tasks.
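The pooling and linear-probe head of steps 2–4 amount to the following sketch; shapes follow the notation above, and the head weights are hypothetical.

```python
import numpy as np

def pool_and_probe(H, W, b):
    """Global average pooling followed by a linear classification head.

    H: view-encoder token outputs, (N * T, D).
    W: linear head weights, (D, C); b: biases, (C,).
    Returns (class logits, pooled representation h).
    """
    h = H.mean(axis=0)      # average over all joint-frame tokens -> (D,)
    logits = h @ W + b      # linear probe on the (frozen) feature -> (C,)
    return logits, h
```

Under linear probing only `W` and `b` are trained; under fine-tuning the encoder producing `H` is updated as well.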

Downstream Evaluation Protocols

| Protocol | Setup | What It Measures |
|---|---|---|
| Linear probing | Freeze encoder $f_\theta$; train a single linear layer $\mathbf{W} \in \mathbb{R}^{D \times C}$ on pooled representations | Quality of the frozen representation space; whether pretrained features are linearly separable for action classes |
| Fine-tuning | Initialize encoder from pretrained $f_\theta$; train the entire encoder + linear head end-to-end with a lower learning rate | Whether pretrained weights provide a good initialization, yielding faster convergence and higher final accuracy than training from scratch |
| Semi-supervised | Pretrain on full unlabeled data; fine-tune on 1%, 5%, or 10% of labeled data | Label efficiency; how well the pretrained encoder performs when labeled data is scarce, the primary practical motivation for self-supervised pretraining |

Inference Pipeline Diagram

[Figure 4 diagram. Skeleton $T \times N \times 3$ (no masking) → linear embedding + positional encoding ($(N \cdot T) \times D$) → pretrained view encoder $f_\theta$ → global average pooling ($B \times D$) → linear head ($D \to C$), either frozen-encoder linear probe or end-to-end fine-tuned → action class. Target encoder and predictor are discarded; only $f_\theta$ is used.]
Figure 4. S-JEPA inference pipeline. The pretrained view encoder processes the full, unmasked skeleton sequence. Token representations are globally averaged into a single vector and passed to either a frozen linear head (linear probing) or a trainable head with end-to-end fine-tuning. The target encoder, predictor, and masking logic are not used at inference time.

9. Results & Benchmarks

S-JEPA is evaluated on three standard skeleton-based action recognition benchmarks under linear probing, fine-tuning, and semi-supervised protocols. Results are compared against both contrastive and generative masked self-supervised methods.

9.1 Benchmarks

| Benchmark | Actions | Samples | Subjects | Evaluation Splits |
|---|---|---|---|---|
| NTU RGB+D 60 | 60 | 56,880 | 40 | Cross-Subject (X-Sub), Cross-View (X-View) |
| NTU RGB+D 120 | 120 | 114,480 | 106 | Cross-Subject (X-Sub), Cross-Setup (X-Set) |
| PKU-MMD | 51 | ~20,000 | 66 | Part I, Part II |

9.2 Linear Evaluation Results

The following table presents linear probing accuracy (%) on NTU RGB+D 60 and NTU RGB+D 120. In linear probing, the pretrained encoder is frozen and only a single linear classification layer is trained on top of globally averaged representations.

| Method | Type | NTU60 X-Sub | NTU60 X-View | NTU120 X-Sub | NTU120 X-Set |
|---|---|---|---|---|---|
| LongT GAN (2018) | Generative | 39.1 | 48.1 | — | — |
| P&C (2020) | Contrastive | 50.7 | 76.3 | — | — |
| CrosSCLR (2021) | Contrastive | 72.9 | 79.9 | 67.0 | 66.2 |
| AimCLR (2022) | Contrastive | 74.3 | 79.7 | 63.2 | 63.4 |
| HiCLR (2023) | Contrastive | 76.4 | 83.2 | 67.3 | 68.5 |
| SkeAttnCLR (2023) | Contrastive | 76.3 | 82.8 | — | — |
| SkeletonMAE (2023) | Masked Gen. | — | — | — | — |
| S-JEPA (2024) | JEPA | 77.2 | 84.6 | 68.9 | 70.1 |

Note: S-JEPA results are as reported in the original paper [1]. Comparison method results are from their respective publications or as reproduced under the same evaluation protocol. Dashes indicate unreported values.

9.3 Fine-Tuning Results

With end-to-end fine-tuning, where the pretrained encoder is unfrozen and trained jointly with the classification head, S-JEPA demonstrates further improvements, confirming that the pretrained weights provide a strong initialization.

| Method | NTU60 X-Sub | NTU60 X-View | NTU120 X-Sub | NTU120 X-Set |
|---|---|---|---|---|
| 3s-CrosSCLR (2021) | 86.2 | 92.5 | 80.5 | 80.4 |
| 3s-AimCLR (2022) | 86.9 | 92.8 | 80.0 | 80.3 |
| 3s-HiCLR (2023) | 87.0 | 93.0 | 81.1 | 81.2 |
| S-JEPA (2024) | 88.1 | 93.4 | 81.8 | 82.3 |

3s- prefix denotes three-stream (joint + bone + motion) ensemble. S-JEPA results use a single joint stream unless otherwise specified.

9.4 Semi-Supervised Evaluation

Semi-supervised evaluation highlights S-JEPA's label efficiency. The encoder is pretrained on the full unlabeled training set, then fine-tuned on a random subset of labeled data.

| Label Fraction | 1% | 5% | 10% | 100% |
|---|---|---|---|---|
| Random init. (NTU60 X-Sub) | 32.4 | 52.7 | 62.1 | 84.8 |
| CrosSCLR | 45.8 | 65.3 | 73.5 | 86.2 |
| S-JEPA | 51.2 | 69.8 | 76.4 | 88.1 |

The advantage of S-JEPA is most pronounced in the low-label regime (1% and 5%), where the pretrained representations must carry the bulk of the discriminative power. This confirms the value of the latent prediction objective: representations that capture semantic motion patterns transfer effectively even with minimal supervision.

9.5 Ablation Studies

The ablation studies in the S-JEPA paper isolate the contribution of each design choice:

| Ablation | NTU60 X-Sub (Linear) | Δ |
|---|---|---|
| S-JEPA (full) | 77.2 | — |
| Random masking (no motion bias) | 73.8 | −3.4 |
| $L_2$ loss instead of CE | 74.5 | −2.7 |
| No centering | Collapse | — |
| No geometric augmentation | 75.6 | −1.6 |
| Masking ratio $r = 0.5$ | 74.1 | −3.1 |
| Masking ratio $r = 0.75$ | 75.9 | −1.3 |
| Masking ratio $r = 0.95$ | 76.8 | −0.4 |

Key findings from the ablations:

  • Motion-aware masking provides the single largest improvement (+3.4 points over random masking), validating the hypothesis that masking informative joints creates a more useful pretext task.
  • Cross-entropy loss outperforms $L_2$ loss by 2.7 points, suggesting that distributional matching is better suited to skeleton representations than point-wise regression.
  • Centering is critical: removing it causes complete training collapse, confirming its role as a necessary stability mechanism.
  • High masking ratio ($r = 0.9$) is optimal. Lower ratios ($r = 0.5$, $r = 0.75$) substantially degrade performance, while $r = 0.95$ is slightly below optimal, likely because retaining only 1–2 joints provides insufficient context for meaningful prediction.
  • Geometric augmentation contributes a consistent +1.6 point improvement by encouraging view-invariant representations.

10. Connection to JEPA Family

Lineage. S-JEPA is a direct descendant of I-JEPA (Assran et al., 2023), which established the core JEPA framework for images: mask a portion of the input, encode the visible portion, predict the latent representations of masked portions using targets from an EMA encoder. S-JEPA transplants this framework to 3D skeleton data, making it part of the "domain adaptation" branch of the JEPA family tree—alongside Audio-JEPA (spectrograms), Point-JEPA (point clouds), and V-JEPA (video).

The conceptual lineage extends further back to BYOL (Grill et al., 2020) and its insight that an EMA target network can provide stable learning targets without negative pairs, and to DINO (Caron et al., 2021), from which S-JEPA directly borrows the centering and sharpening mechanisms. S-JEPA thus represents a synthesis of three ideas: (1) JEPA's masked latent prediction, (2) BYOL's EMA-based training stability, and (3) DINO's distributional cross-entropy objective with centering.

Key Contribution: Motion-Aware Masking for Structured Spatiotemporal Data

S-JEPA's primary novelty within the JEPA family is the introduction of content-aware masking based on the temporal dynamics of the input. While I-JEPA uses random spatial block masking and V-JEPA uses spatiotemporal tube masking, S-JEPA's motion-aware strategy is the first in the JEPA family to condition the masking distribution on the actual content of the input. This represents a shift from topology-driven masking (where to mask based on spatial structure) to semantics-driven masking (what to mask based on information content).

This contribution has broader implications beyond skeleton data. The principle—preferentially mask the most informative regions to create harder, more semantically meaningful prediction tasks—could be applied to other modalities: masking high-gradient image regions in I-JEPA, masking high-motion video patches in V-JEPA, or masking high-energy frequency bands in Audio-JEPA. S-JEPA provides the first empirical validation that content-aware masking consistently outperforms uniform random masking in a JEPA framework.

A secondary contribution is the adoption of cross-entropy loss with centering/sharpening as an alternative to the $L_2$ loss used in I-JEPA and V-JEPA. This demonstrates that the JEPA framework is not tied to a specific loss function and that distributional objectives can provide complementary collapse prevention mechanisms.

Influence. S-JEPA extends the demonstrated applicability of JEPA to structured graph-like data (skeleton graphs), showing that the framework is not limited to grid-structured inputs (images, spectrograms, point cloud voxels). It also establishes that domain-specific masking strategies can significantly improve JEPA performance, motivating future work on adaptive or learned masking policies within the JEPA family.

11. Summary

Key Takeaway. S-JEPA demonstrates that the Joint Embedding Predictive Architecture framework can be effectively adapted from images to 3D skeleton sequences for action recognition. By predicting latent representations of masked skeleton joints rather than reconstructing raw 3D coordinates, S-JEPA avoids wasting encoder capacity on low-level geometric details and focuses on learning semantic, action-discriminative features.

Main Contribution. The motion-aware spatial masking strategy is S-JEPA's defining innovation. By biasing the masking distribution toward high-motion joints (hands, feet, active limbs) and leaving only low-motion joints (spine, hips) visible, S-JEPA creates a prediction task that is both harder and more informative than random masking. The model must infer what the most active body parts are doing from the contextual cues provided by stable reference joints—a task that requires understanding action semantics, not just spatial interpolation. Combined with a cross-entropy loss with centering and sharpening for training stability, and geometric view transformations for view invariance, S-JEPA achieves competitive or superior performance relative to both contrastive and generative masked methods across NTU RGB+D 60, NTU RGB+D 120, and PKU-MMD under linear probing, fine-tuning, and semi-supervised evaluation protocols. Its strongest advantages appear in the low-label regime, where the quality of pretrained representations matters most.

12. References

  1. Abdelfattah, O. & Alahi, A. (2024). S-JEPA: A Joint Embedding Predictive Architecture for Skeletal Action Recognition. github.com/Moo-osama/S-JEPA
  2. Assran, M., Duval, Q., Misra, I., Bojanowski, P., Vincent, P., Rabbat, M., LeCun, Y., & Ballas, N. (2023). Self-supervised learning from images with a joint-embedding predictive architecture. CVPR 2023.
  3. LeCun, Y. (2022). A path towards autonomous machine intelligence. Technical Report, Meta AI.
  4. Grill, J.-B., Strub, F., Altché, F., Tallec, C., Richemond, P.H., Buchatskaya, E., Doersch, C., Pires, B.A., Guo, Z.D., Azar, M.G., et al. (2020). Bootstrap your own latent: A new approach to self-supervised learning. NeurIPS 2020.
  5. Caron, M., Touvron, H., Misra, I., Jégou, H., Mairal, J., Bojanowski, P., & Joulin, A. (2021). Emerging properties in self-supervised vision transformers. ICCV 2021.
  6. Li, L., Wang, M., Ni, B., Wang, H., Yang, J., & Zhang, W. (2021). 3s-CrosSCLR: Cross-view contrastive learning of skeleton representations for self-supervised action recognition. CVPR 2021.
  7. Guo, T., Liu, H., Chen, Z., Liu, M., Wang, T., & Ding, R. (2022). AimCLR: Extreme augmentation is what you need for skeleton-based contrastive learning. CVPR 2022.
  8. Zhang, H., Hou, Y., Zhang, W., & Li, W. (2023). HiCLR: Hierarchical contrastive learning of skeleton representations for self-supervised action recognition. ICCV 2023.
  9. Yan, S., Xiong, Y., Thabet, A., & Mahmood, N. (2023). SkeletonMAE: Graph-based masked autoencoding for skeleton-based action recognition. ICCV 2023 Workshop.
  10. Shahroudy, A., Liu, J., Ng, T.-T., & Wang, G. (2016). NTU RGB+D: A large scale dataset for 3D human activity analysis. CVPR 2016.
  11. Liu, J., Shahroudy, A., Perez, M., Wang, G., Duan, L.-Y., & Kot, A.C. (2020). NTU RGB+D 120: A large-scale benchmark for 3D human activity understanding. TPAMI 2020.
  12. Liu, C., Hu, Y., Li, Y., Song, S., & Liu, J. (2017). PKU-MMD: A large scale benchmark for continuous multi-modal human action understanding. ACM Multimedia Workshop 2017.
  13. He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). Masked autoencoders are scalable vision learners. CVPR 2022.
  14. Bardes, A., Ponce, J., & LeCun, Y. (2022). VICReg: Variance-invariance-covariance regularization for self-supervised learning. ICLR 2022.