Authors: Hu, Cheng, Xie, Li, Zhu
Date: 2024-09
Category: 3D / Spatial
Derives from: I-JEPA

3D-JEPA: A Joint Embedding Predictive Architecture for 3D Self-Supervised Representation Learning

Hu, Cheng, Xie, Li, Zhu • September 2024 • arXiv 2409.15803

1. Introduction

Self-supervised learning on 3D point clouds has historically been dominated by reconstruction-based methods. Approaches such as Point-MAE and Point-BERT mask portions of a point cloud and train the network to reconstruct the missing 3D coordinates or discrete point tokens in input space. While effective, this paradigm carries an inherent limitation: the model expends significant capacity learning to reproduce low-level geometric details—surface normals, local curvatures, exact point positions—that may be irrelevant for high-level semantic understanding. A decoder powerful enough to reconstruct raw point coordinates can succeed without ever learning transferable, abstract representations.

In the 2D image domain, the Joint Embedding Predictive Architecture (JEPA) framework, and specifically I-JEPA (Assran et al., 2023), demonstrated that predicting in representation space rather than input space yields representations that are more semantically rich and more sample-efficient for downstream tasks. I-JEPA removes the decoder entirely: a lightweight predictor network maps context representations to target representations produced by an exponential moving average (EMA) encoder, and the loss is computed purely in the latent space. This avoids the pixel-reconstruction bottleneck and encourages the encoder to capture high-level, abstract features.

3D-JEPA (Hu et al., 2024) transfers this principle to the 3D domain. The core question it addresses is: Can latent-space prediction replace point-coordinate reconstruction for self-supervised 3D representation learning, and if so, how should the masking strategy be adapted for the irregular, unordered geometry of point clouds?

The contributions of 3D-JEPA are threefold:

  1. Latent prediction for 3D point clouds. 3D-JEPA is the first work to apply the JEPA framework—context-encoder, EMA target-encoder, predictor, representation-space loss—to 3D point cloud data, eliminating the need for point-coordinate or discrete-token reconstruction.
  2. Multi-block spatial sampling in 3D. Extending I-JEPA's multi-block masking from 2D image grids to 3D Euclidean space, the paper proposes a spatial sampling strategy that selects multiple target blocks as spatially contiguous clusters in 3D, using the complement as context. This ensures the prediction task requires genuine 3D spatial reasoning rather than simple interpolation.
  3. Competitive performance without reconstruction. 3D-JEPA achieves results competitive with or superior to Point-MAE and Point-BERT on standard benchmarks (ModelNet40, ScanObjectNN, ShapeNetPart) despite never reconstructing a single point coordinate during pretraining.

The key departure from I-JEPA lies in the input modality: images are dense, grid-structured signals where patches can be extracted by regular spatial tiling. Point clouds are sparse, unordered, and irregularly distributed. 3D-JEPA must address point cloud tokenization (converting raw points into a sequence of patch tokens), spatial block sampling in continuous 3D Euclidean space (rather than on a discrete 2D grid), and positional encoding that respects 3D geometry. These adaptations are non-trivial and constitute the primary methodological contribution of the work.

2. Method

To understand 3D-JEPA, consider an analogy. Imagine you are shown a sculpture from one angle—say the front and left side—and asked to describe what the back looks like. A novice might try to predict the exact surface texture, the precise curve of each ridge. An expert, by contrast, would reason at a higher level: "Given the style and symmetry I see, the back likely has a similar motif with these broad structural features." The expert is predicting in concept space, not surface space.

3D-JEPA follows the expert's strategy. Given a partial view of a 3D object (the context), it does not attempt to reconstruct the missing points. Instead, it predicts what a separate, stable encoder would say about those missing regions—their abstract, high-level representations. The prediction target is not geometry; it is semantics.

Intuition: Why predict representations instead of points?
Consider two patches on opposite sides of a chair. Their raw 3D coordinates are completely different, but semantically they are both "leg of a chair." A reconstruction loss treats them as unrelated prediction targets. A representation-space loss, by contrast, can recognize their semantic similarity because the target encoder has already projected them into a space where function matters more than position. By predicting in this space, the context encoder is forced to learn what things mean, not just where they are.

The training pipeline has four core components:

  1. Point cloud tokenization. The raw point cloud (e.g., 1024 points) is divided into $N$ spatial patches using Farthest Point Sampling (FPS) to select center points, followed by K-Nearest Neighbors (KNN) to group nearby points into each patch. A small PointNet-style network (mini-PointNet) encodes each patch into a $D$-dimensional token. This converts the irregular point cloud into an ordered sequence of tokens, analogous to image patch embeddings in a Vision Transformer.
  2. Multi-block spatial masking. Multiple target blocks are sampled by selecting seed patches and expanding outward through spatial adjacency in 3D. The remaining patches form the context. Because blocks are spatially contiguous in 3D, the model cannot trivially interpolate the targets from nearby context patches—it must reason about the global shape.
  3. Dual encoders with asymmetric processing. The context encoder (trainable) processes only the visible context tokens. The target encoder (EMA-updated, with no gradient flow) processes the target tokens (or, as in I-JEPA, the full token set from which the target representations are read off) to produce the prediction targets. The asymmetry (partial input for context, full input for target) is critical: it prevents the model from learning a trivial identity mapping.
  4. Predictor. A lightweight Transformer takes the context encoder's output along with positional information for the masked locations and predicts the target encoder's representations at those locations. The loss is the mean squared error (or smooth L1) between the predictor's output and the target encoder's output at the masked positions.
Intuition: Why multi-block masking matters in 3D.
In 2D images, masking a single large rectangular block is already a challenging task. In 3D, however, objects are sparser and spatial relationships are more complex—there is no canonical "up-down-left-right" grid. By sampling multiple spatially disjoint target blocks, 3D-JEPA ensures the model must integrate information across the entire visible context to predict each block. A single contiguous mask could often be resolved by local extrapolation; multiple blocks demand global understanding of the object's structure.

Crucially, no component of 3D-JEPA ever attempts to decode back to 3D point coordinates. The entire training signal is in representation space. This is what distinguishes JEPA-family methods from masked autoencoders and makes the learned representations inherently abstract and semantic.

3. Model Overview

At-a-Glance

| Attribute | 3D-JEPA |
| --- | --- |
| Input | 3D point cloud ($P$ points, e.g., 1024 points × 3 coords) |
| Tokenization | FPS + KNN grouping → mini-PointNet per patch → $N$ tokens of dim $D$ |
| Masking | Multi-block spatial sampling in 3D Euclidean space; $M$ target blocks |
| Context Encoder | Standard Transformer (trainable); processes context tokens only |
| Target Encoder | Same architecture, EMA-updated (no gradients) |
| Predictor | Lightweight Transformer; maps context representations → target representations |
| Loss | L2 (MSE) in representation space between predictor output and target encoder output |
| Key Result | Competitive with Point-MAE / Point-BERT on ModelNet40, ScanObjectNN, ShapeNetPart without point reconstruction |
| Derives From | I-JEPA (Assran et al., 2023) |

Training Architecture Diagram

[Figure 1 diagram: point cloud → tokenizer (FPS + KNN + mini-PointNet) → $N \times D$ patch tokens → multi-block mask ($M$ targets, 1 context) → trainable context encoder ($N_c \times D$) and EMA target encoder ($N_t \times D$, no grad) → predictor → L2 loss in representation space]
Figure 1. 3D-JEPA training architecture. The point cloud is tokenized into $N$ patch tokens. Multi-block masking splits tokens into context ($N_c$) and target ($N_t$) sets. The trainable context encoder processes context tokens; the EMA target encoder (dashed border, no gradients) produces prediction targets. The predictor maps context representations to target representations, with L2 loss computed in representation space. Gradient flows only through the context encoder and predictor (solid green arrows).

4. Main Components of 3D-JEPA

4.1 Point Cloud Tokenizer

WHAT: The tokenizer converts a raw point cloud $\mathbf{X} \in \mathbb{R}^{P \times 3}$ (with $P$ points, each having 3D coordinates) into a sequence of $N$ patch tokens $\{\mathbf{t}_i\}_{i=1}^{N}$, where each $\mathbf{t}_i \in \mathbb{R}^D$.

HOW: The tokenization follows the standard pipeline established by Point-BERT and Point-MAE:

  1. Center selection via FPS. Farthest Point Sampling selects $N$ center points $\{c_1, \dots, c_N\}$ from $\mathbf{X}$. FPS guarantees approximately uniform coverage of the point cloud's spatial extent. In 3D-JEPA, $N = 64$ patches are used (following convention from prior work).
  2. Neighbor grouping via KNN. For each center $c_i$, the $K$ nearest neighbors (typically $K = 32$) are gathered to form a local patch $\mathbf{G}_i \in \mathbb{R}^{K \times 3}$. Coordinates are normalized relative to the center: $\mathbf{G}_i^{\text{norm}} = \mathbf{G}_i - c_i$.
  3. Mini-PointNet embedding. Each normalized patch is processed by a shared mini-PointNet (an MLP followed by max-pooling over the $K$ points) to produce a $D$-dimensional token: $$\mathbf{t}_i = \text{MaxPool}(\text{MLP}(\mathbf{G}_i^{\text{norm}})) \in \mathbb{R}^D$$
  4. Positional encoding. The 3D coordinates of each center point are mapped to a $D$-dimensional positional embedding via a learned linear projection or sinusoidal encoding: $$\mathbf{p}_i = \text{PosEnc}(c_i) \in \mathbb{R}^D$$ The final input token is $\mathbf{t}_i + \mathbf{p}_i$.

WHY: This tokenization scheme is critical because it bridges the gap between irregular point clouds and the regular token sequences required by Transformer architectures. The use of FPS ensures spatial uniformity, KNN captures local geometric structure, and the mini-PointNet provides permutation-invariant encoding of each local patch. The paper follows the established $N=64$, $K=32$ configuration to enable direct comparison with Point-MAE and Point-BERT.
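The four-step pipeline above can be sketched in NumPy. This is an illustrative stand-in, not the paper's code: greedy FPS and KNN grouping follow the standard recipe, while a single random-weight linear layer plus ReLU and max-pool substitutes for the learned mini-PointNet.

```python
import numpy as np

def farthest_point_sampling(points, n_centers, seed=0):
    """Greedy FPS: repeatedly pick the point farthest from all chosen centers."""
    rng = np.random.default_rng(seed)
    idx = [int(rng.integers(len(points)))]
    dist = np.full(len(points), np.inf)
    for _ in range(n_centers - 1):
        dist = np.minimum(dist, np.linalg.norm(points - points[idx[-1]], axis=1))
        idx.append(int(dist.argmax()))
    return np.array(idx)

def tokenize(points, n_patches=64, k=32, dim=384, seed=0):
    """FPS centers -> KNN groups -> center-normalize -> mini-PointNet token."""
    rng = np.random.default_rng(seed)
    centers = points[farthest_point_sampling(points, n_patches, seed)]
    d = np.linalg.norm(points[None, :, :] - centers[:, None, :], axis=-1)  # (N, P)
    groups = points[np.argsort(d, axis=1)[:, :k]] - centers[:, None, :]    # (N, K, 3)
    w = rng.standard_normal((3, dim)) / np.sqrt(3)       # untrained stand-in MLP
    tokens = np.maximum(groups @ w, 0.0).max(axis=1)     # ReLU, then max-pool over K
    return tokens, centers

pts = np.random.default_rng(1).standard_normal((1024, 3))
tok, ctr = tokenize(pts)
print(tok.shape, ctr.shape)  # (64, 384) (64, 3)
```

Positional embeddings would then be computed from `ctr` and added to `tok` before the encoder.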

4.2 Context Encoder

WHAT: The context encoder $f_\theta$ is a standard Transformer that processes the visible (unmasked) context tokens and produces contextualized representations.

HOW: Given the set of context token indices $\mathcal{C}$ (the complement of the masked target indices), the context encoder receives the tokens $\{\mathbf{t}_i + \mathbf{p}_i\}_{i \in \mathcal{C}}$ and applies $L$ Transformer layers with multi-head self-attention and feed-forward networks:

$$\mathbf{h}_i^{(l+1)} = \text{TransformerLayer}^{(l)}(\{\mathbf{h}_j^{(l)}\}_{j \in \mathcal{C}}), \quad i \in \mathcal{C}$$

where $\mathbf{h}_i^{(0)} = \mathbf{t}_i + \mathbf{p}_i$. The architecture uses a standard ViT-Small-like configuration: $L=12$ layers, hidden dimension $D=384$, $h=6$ attention heads, and FFN hidden dimension $4D = 1536$. The output is $\{\mathbf{h}_i\}_{i \in \mathcal{C}}$, each in $\mathbb{R}^D$.

WHY: Processing only context tokens (not the full sequence with mask tokens) is a deliberate design choice inherited from I-JEPA. It serves two purposes: (1) computational efficiency, since the encoder never processes mask tokens, and (2) it prevents information leakage—the encoder must form its representations without any positional hint from mask-token placeholders about what is missing.
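The "context tokens only" design can be seen in miniature below: a single untrained attention head (identity stand-ins for the learned $W_q$, $W_k$, $W_v$ projections) operates over exactly the rows it is given, so when only context tokens are passed in, masked positions simply do not exist in the computation.

```python
import numpy as np

def self_attention(x):
    """One attention head over exactly the tokens provided (no mask tokens)."""
    d = x.shape[-1]
    scores = x @ x.T / np.sqrt(d)
    scores = np.exp(scores - scores.max(axis=-1, keepdims=True))  # stable softmax
    attn = scores / scores.sum(axis=-1, keepdims=True)
    return attn @ x

ctx = np.random.default_rng(0).standard_normal((26, 384))  # N_c context tokens only
out = self_attention(ctx)
print(out.shape)  # (26, 384)
```

Because the sequence length is $N_c$ rather than $N$, the quadratic attention cost also shrinks accordingly.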

4.3 Target Encoder (EMA)

WHAT: The target encoder $f_{\bar{\theta}}$ shares the context encoder's architecture, but its parameters $\bar{\theta}$ are updated via an exponential moving average (EMA) of the context encoder's parameters $\theta$. No gradients flow through the target encoder.

HOW: After each training step, the target encoder parameters are updated as:

$$\bar{\theta} \leftarrow \tau \bar{\theta} + (1 - \tau) \theta$$

where $\tau$ is the EMA momentum coefficient, typically following a cosine schedule from an initial value (e.g., $\tau_0 = 0.996$) to a final value (e.g., $\tau_1 = 1.0$) over the course of training. The target encoder processes the target tokens $\{\mathbf{t}_i + \mathbf{p}_i\}_{i \in \mathcal{T}}$ (where $\mathcal{T}$ is the set of target patch indices) and produces target representations $\{\mathbf{z}_i\}_{i \in \mathcal{T}}$.
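A minimal sketch of the EMA update and its cosine momentum ramp; the exact schedule shape follows I-JEPA convention and is an assumption here, not taken from the paper.

```python
import math

def ema_momentum(step, total_steps, tau0=0.996, tau1=1.0):
    """Cosine ramp of the EMA momentum from tau0 at step 0 to tau1 at the end."""
    t = step / max(total_steps, 1)
    return tau1 - 0.5 * (tau1 - tau0) * (1.0 + math.cos(math.pi * t))

def ema_update(theta_bar, theta, tau):
    """Parameter-wise: theta_bar <- tau * theta_bar + (1 - tau) * theta."""
    return {k: tau * v + (1.0 - tau) * theta[k] for k, v in theta_bar.items()}

tau = ema_momentum(step=0, total_steps=1000)
print(round(tau, 3))  # 0.996 at the start of training
```

Early in training the target encoder tracks the context encoder quickly (small $\tau$ gap); near the end it is almost frozen ($\tau \to 1$).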

WHY: The EMA target encoder serves as a slowly evolving reference that provides stable prediction targets. This mechanism, inherited from BYOL and refined in I-JEPA, is one of the key components that prevents representational collapse—the scenario where the model learns to map all inputs to a constant representation. The EMA encoder changes slowly enough to provide consistent targets across training steps, but adapts over time as the context encoder improves. The stop-gradient operation (no backpropagation through the target encoder) creates the necessary asymmetry: the context encoder and predictor must actively improve their representations to match the target, rather than the target degrading to match a trivial prediction.

Note on collapse prevention. While the EMA mechanism is a key contributor to stability, collapse prevention in JEPA-family methods results from the combination of several factors: (1) EMA target stability, (2) stop-gradient asymmetry, (3) predictor bottleneck constraining information flow, and (4) the multi-block masking strategy itself providing a sufficiently challenging prediction task. No single factor is sufficient in isolation.

4.4 Predictor

WHAT: The predictor $g_\phi$ is a lightweight Transformer that takes the context encoder's output representations and positional information for the masked target locations, and predicts the target encoder's representations at those locations.

HOW: The predictor receives two types of input:

  1. The contextualized representations from the context encoder: $\{\mathbf{h}_i\}_{i \in \mathcal{C}}$, paired with their positional embeddings $\{\mathbf{p}_i\}_{i \in \mathcal{C}}$.
  2. Learnable mask tokens $\{\mathbf{m}\}$ (a single shared vector, repeated for each target position), paired with the positional embeddings of the target locations: $\{\mathbf{p}_j\}_{j \in \mathcal{T}}$.

These are concatenated into a single sequence and processed by a smaller Transformer (e.g., $L_p = 6$ layers, same hidden dimension $D$). The predictor's output at the target positions yields predicted representations:

$$\hat{\mathbf{z}}_j = g_\phi(\{\mathbf{h}_i\}_{i \in \mathcal{C}}, \{\mathbf{p}_j\}_{j \in \mathcal{T}})_j, \quad j \in \mathcal{T}$$

WHY: The predictor serves a dual role. First, it bridges the gap between the context encoder's representation space and the target encoder's representation space—since the two encoders' parameters diverge (one is EMA-updated), their representations inhabit slightly different manifolds. Second, the predictor's limited capacity (fewer layers than the encoder) acts as an information bottleneck that regularizes learning. If the predictor were too powerful, it could learn to predict target representations without the context encoder needing to produce informative representations. The narrow predictor forces the context encoder to do the heavy lifting of semantic encoding.
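The predictor's input assembly, concatenating context representations (with their positions) and one shared mask token per target position (with that target's position), can be sketched as follows. The identity `predictor` is a placeholder for the real $L_p$-layer lightweight Transformer; all names here are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
D, Nc, Nt = 384, 26, 38
h_ctx = rng.standard_normal((Nc, D))   # context encoder outputs {h_i}
p_ctx = rng.standard_normal((Nc, D))   # positional embeddings, context patches
p_tgt = rng.standard_normal((Nt, D))   # positional embeddings, target patches
mask_token = rng.standard_normal(D)    # single shared learnable vector m

# Context reps (+pos) followed by one mask token per target slot (+target pos)
seq = np.concatenate([h_ctx + p_ctx, mask_token[None, :] + p_tgt], axis=0)

def predictor(x):
    """Identity stand-in for the lightweight Transformer."""
    return x

z_hat = predictor(seq)[Nc:]  # predictions read off at the Nt target positions
print(z_hat.shape)  # (38, 384)
```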

4.5 Multi-Block Spatial Masking Strategy

WHAT: The masking strategy determines which patches become targets (to be predicted) and which remain visible as context. 3D-JEPA uses a multi-block approach: multiple spatially contiguous blocks are sampled as targets, with the remaining patches forming the context.

HOW: The multi-block sampling procedure operates as follows:

  1. Seed selection. $M$ seed patches (e.g., $M = 4$) are randomly selected from the $N$ patch centers.
  2. Block expansion. Each seed is expanded into a spatially contiguous block by including its $k$-nearest neighboring patches in 3D Euclidean space. The expansion radius or neighbor count is chosen to achieve a target masking ratio (e.g., each block covers approximately $\frac{N \cdot r}{M}$ patches, where $r$ is the total target ratio, typically $r \approx 0.6$, i.e., 60% of patches are masked as targets).
  3. Context assignment. All patches not included in any target block form the context set $\mathcal{C}$. Since blocks may overlap slightly, the effective context ratio is approximately $1 - r$.

The key hyperparameters are: number of target blocks $M$, target masking ratio $r$, and expansion strategy (fixed-$k$ neighbors or radius-based). The paper explores $M \in \{1, 2, 4\}$ and finds that $M = 4$ provides the best performance, validating the multi-block design.
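The sampling procedure above translates directly into a few lines of NumPy; `multi_block_mask` is an illustrative name, not from the paper, and the fixed-$k$ expansion variant is shown.

```python
import numpy as np

def multi_block_mask(centers, M=4, r=0.6, rng=None):
    """Sample M spatially contiguous target blocks; context is the complement."""
    rng = rng or np.random.default_rng()
    N = len(centers)
    n_per_block = int(N * r / M)                 # patches per target block
    targets = set()
    for _ in range(M):
        # Seed patch not already inside an existing target block
        seed = int(rng.choice([i for i in range(N) if i not in targets]))
        d = np.linalg.norm(centers - centers[seed], axis=1)
        targets.update(int(i) for i in np.argsort(d)[:n_per_block])
    context = [i for i in range(N) if i not in targets]
    return sorted(targets), context

centers = np.random.default_rng(0).standard_normal((64, 3))
T, C = multi_block_mask(centers, rng=np.random.default_rng(1))
print(len(T) + len(C))  # 64; len(T) <= 36 because overlapping blocks merge
```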

[Figure 2 diagram: full point cloud ($N=64$ patches) → sample $M=4$ seed patches at random → expand each seed to its $k$-nearest neighbors in 3D Euclidean space → ~60% target (masked) vs. ~40% context (visible); context patches are fed to the context encoder]
Figure 2. Multi-block spatial masking in 3D-JEPA. Four target blocks (green outlines) are sampled as spatially contiguous clusters in 3D Euclidean space. The remaining patches (muted outlines) form the context. Each block is generated by selecting a seed patch and expanding to its spatial neighbors. This ensures the prediction task requires non-local 3D reasoning.

WHY: The ablation on the number of target blocks is one of the paper's key findings. Using $M=1$ (a single contiguous block) yields weaker representations because the model can often predict the single target by local extrapolation from nearby context patches. Increasing to $M=4$ forces the model to maintain a global understanding of the object: predicting four spatially separated blocks from the remaining context requires integrating information across the entire visible shape. This design directly mirrors the finding from I-JEPA that multi-block masking outperforms single-block masking on 2D images, and validates that the same principle holds in the 3D domain.

4.6 Loss Function

WHAT: The loss function measures the discrepancy between the predictor's output and the target encoder's output at the masked positions. It operates entirely in representation space—no point coordinates are involved.

HOW: Let $\hat{\mathbf{z}}_j$ be the predictor's output for target position $j \in \mathcal{T}$, and let $\mathbf{z}_j$ be the target encoder's output for the same position. The loss is the average mean squared error over all target positions across all target blocks:

$$\mathcal{L} = \frac{1}{|\mathcal{T}|} \sum_{j \in \mathcal{T}} \| \hat{\mathbf{z}}_j - \text{sg}(\mathbf{z}_j) \|_2^2$$

where:

  • $\hat{\mathbf{z}}_j \in \mathbb{R}^D$ is the predictor's predicted representation for target patch $j$
  • $\mathbf{z}_j \in \mathbb{R}^D$ is the target encoder's representation for target patch $j$
  • $\text{sg}(\cdot)$ denotes the stop-gradient operator: no gradients are backpropagated through $\mathbf{z}_j$
  • $|\mathcal{T}|$ is the total number of target patches (summed across all $M$ target blocks)
  • $\|\cdot\|_2^2$ is the squared L2 norm

In practice, the target representations $\mathbf{z}_j$ may be normalized (e.g., via layer normalization or L2 normalization) before the loss computation to stabilize training and prevent the targets from collapsing to zero or growing unboundedly.

The total training objective is simply $\mathcal{L}$—there are no additional regularization terms, contrastive losses, or reconstruction objectives. The simplicity of this loss is a hallmark of the JEPA framework.

WHY: L2 loss in representation space is preferred over alternatives for several reasons. Compared to contrastive losses (e.g., InfoNCE), it does not require negative samples, removing a source of computational overhead and hyperparameter sensitivity. Compared to reconstruction losses in input space, it avoids the capacity expenditure on low-level detail reproduction. The stop-gradient on $\mathbf{z}_j$ is essential: without it, the trivial solution of collapsing both representations to zero would minimize the loss. The normalization of targets is an additional safeguard against the target encoder producing representations that are easy to predict but uninformative (e.g., near-constant vectors).
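A NumPy sketch of the loss with optional LayerNorm-style target normalization. Because the targets here are plain arrays, no gradient can flow through them, which mirrors the $\text{sg}(\cdot)$ operator.

```python
import numpy as np

def normalize_targets(z):
    """Per-token LayerNorm-style normalization of targets (optional, per the text)."""
    return (z - z.mean(-1, keepdims=True)) / (z.std(-1, keepdims=True) + 1e-6)

def jepa_loss(z_hat, z_tgt):
    """(1/|T|) * sum_j ||z_hat_j - sg(z_j)||_2^2, averaged over target patches."""
    z = normalize_targets(z_tgt)  # targets held constant: the stop-gradient
    return float(np.mean(np.sum((z_hat - z) ** 2, axis=-1)))

z_t = np.random.default_rng(0).standard_normal((38, 384))
print(jepa_loss(normalize_targets(z_t), z_t))  # 0.0 for a perfect prediction
```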

5. Implementation Details

The following table summarizes the key hyperparameters used in 3D-JEPA. Where the paper does not explicitly report a value, we note that the value is inferred from I-JEPA defaults or standard point cloud Transformer practice, and mark these with an asterisk (*).

| Hyperparameter | Value | Notes |
| --- | --- | --- |
| Input points | 1024 | Standard point cloud benchmark input size |
| Number of patches ($N$) | 64 | Selected via FPS |
| Points per patch ($K$) | 32 | KNN neighbors |
| Encoder layers ($L$) | 12 | Standard Transformer depth |
| Hidden dimension ($D$) | 384 | ViT-Small scale |
| Attention heads ($h$) | 6 | $D/h = 64$ per head |
| FFN dimension | 1536 | $4 \times D$ |
| Predictor layers ($L_p$) | 6* | Lighter than encoder; inferred from I-JEPA conventions |
| Target blocks ($M$) | 4 | Ablation shows $M=4$ optimal |
| Masking ratio ($r$) | ~0.6 | ~60% of patches are targets |
| Optimizer | AdamW | $\beta_1 = 0.9$, $\beta_2 = 0.999$* |
| Base learning rate | 1.5e-4* | With cosine decay schedule |
| Warmup epochs | 10* | Linear warmup |
| Total epochs | 300 | Pretraining on ShapeNet |
| Batch size | 128* | Per-GPU or effective batch size |
| Weight decay | 0.05* | Standard for Transformer pretraining |
| EMA momentum ($\tau$) | 0.996 → 1.0 | Cosine schedule over training |
| Pretraining dataset | ShapeNet | ~51K 3D shapes across 55 categories |
| GPUs | Not explicitly reported | Likely 1–4 GPUs given dataset scale |

Values marked with * are inferred from standard practice (I-JEPA defaults, Point-MAE/Point-BERT conventions) and may differ from the paper's exact configuration. No public repository is available for verification.

6. Algorithm

Algorithm 1: 3D-JEPA Pretraining (One Epoch)
Input: Dataset $\mathcal{D}$ of point clouds; context encoder $f_\theta$; target encoder $f_{\bar{\theta}}$; predictor $g_\phi$; EMA schedule $\tau(t)$; number of target blocks $M$; masking ratio $r$
Output: Updated parameters $\theta$, $\phi$, $\bar{\theta}$
1 for each mini-batch $\{\mathbf{X}_b\}_{b=1}^{B}$ in $\mathcal{D}$ do
2 for each $\mathbf{X}_b$ do
3 // Tokenize point cloud
4 $\{c_i\}_{i=1}^{N} \leftarrow \text{FPS}(\mathbf{X}_b, N)$ // Select N center points
5 $\{\mathbf{G}_i\}_{i=1}^{N} \leftarrow \text{KNN}(\mathbf{X}_b, \{c_i\}, K)$ // Group K neighbors per center
6 $\{\mathbf{t}_i\}_{i=1}^{N} \leftarrow \text{MiniPointNet}(\{\mathbf{G}_i - c_i\})$ // Embed patches, each t_i ∈ ℝ^D
7 $\{\mathbf{p}_i\}_{i=1}^{N} \leftarrow \text{PosEnc}(\{c_i\})$ // Positional embeddings from 3D coords
8 // Multi-block masking
9 $\mathcal{T}, \mathcal{C} \leftarrow \text{MultiBlockMask}(\{c_i\}, M, r)$ // Target and context index sets
10 // Context encoder (trainable): process context tokens only
11 $\{\mathbf{h}_i\}_{i \in \mathcal{C}} \leftarrow f_\theta(\{\mathbf{t}_i + \mathbf{p}_i\}_{i \in \mathcal{C}})$
12 // Target encoder (EMA, no grad): process target tokens
13 with no_grad():
14 $\{\mathbf{z}_j\}_{j \in \mathcal{T}} \leftarrow f_{\bar{\theta}}(\{\mathbf{t}_j + \mathbf{p}_j\}_{j \in \mathcal{T}})$
15 $\{\mathbf{z}_j\}_{j \in \mathcal{T}} \leftarrow \text{Normalize}(\{\mathbf{z}_j\})$ // Optional: LayerNorm or L2-norm targets
16 // Predictor: predict target representations from context
17 $\{\hat{\mathbf{z}}_j\}_{j \in \mathcal{T}} \leftarrow g_\phi(\{\mathbf{h}_i\}_{i \in \mathcal{C}}, \{\mathbf{p}_i\}_{i \in \mathcal{C}}, \{\mathbf{p}_j\}_{j \in \mathcal{T}})$
18 end for
19 // Compute loss over batch
20 $\mathcal{L} \leftarrow \frac{1}{B \cdot |\mathcal{T}|} \sum_{b=1}^{B} \sum_{j \in \mathcal{T}} \| \hat{\mathbf{z}}_j^{(b)} - \text{sg}(\mathbf{z}_j^{(b)}) \|_2^2$
21 // Update context encoder and predictor via gradient descent
22 $\theta, \phi \leftarrow \text{AdamW}(\nabla_{\theta, \phi} \mathcal{L})$
23 // Update target encoder via EMA
24 $\bar{\theta} \leftarrow \tau(t) \cdot \bar{\theta} + (1 - \tau(t)) \cdot \theta$
25 end for
Algorithm 2: Multi-Block Spatial Sampling in 3D
Input: Patch centers $\{c_i\}_{i=1}^{N} \subset \mathbb{R}^3$; number of target blocks $M$; target masking ratio $r$
Output: Target index set $\mathcal{T}$; context index set $\mathcal{C}$
1 $\mathcal{T} \leftarrow \emptyset$
2 $n_{\text{per\_block}} \leftarrow \lfloor N \cdot r / M \rfloor$ // Patches per target block
3 for $m = 1$ to $M$ do
4 $s_m \leftarrow \text{RandomSample}(\{1, \dots, N\} \setminus \mathcal{T})$ // Select seed not already in targets
5 // Compute Euclidean distances from seed to all other patch centers
6 $d_i \leftarrow \|c_i - c_{s_m}\|_2 \quad \forall i \in \{1, \dots, N\}$
7 // Select nearest n_per_block patches (including seed)
8 $\mathcal{B}_m \leftarrow \text{ArgTopK\_Smallest}(d, n_{\text{per\_block}})$
9 $\mathcal{T} \leftarrow \mathcal{T} \cup \mathcal{B}_m$
10 end for
11 $\mathcal{C} \leftarrow \{1, \dots, N\} \setminus \mathcal{T}$ // Context is all non-target patches
12 return $\mathcal{T}, \mathcal{C}$

7. Training

Step-by-Step: One Training Iteration

A single training iteration of 3D-JEPA proceeds through the following stages:

Step 1: Tokenization. A mini-batch of $B$ point clouds, each with $P=1024$ points, is loaded. For each point cloud, FPS selects $N=64$ center points, KNN groups $K=32$ neighbors per center, and the mini-PointNet encodes each group to produce $N$ tokens of dimension $D=384$. Positional embeddings derived from the 3D center coordinates are added. Output shape per sample: $N \times D = 64 \times 384$.

Step 2: Multi-block masking. For each sample, Algorithm 2 is invoked with $M=4$ target blocks and masking ratio $r \approx 0.6$. This produces approximately $N_t = \lfloor 64 \times 0.6 \rfloor = 38$ target patches and $N_c = 64 - 38 = 26$ context patches per sample. The exact split varies per sample due to block overlap.

Step 3: Context encoding. The $N_c$ context tokens (with positional embeddings) are fed through the 12-layer Transformer context encoder $f_\theta$. Self-attention is computed only over the context tokens. Output: $N_c \times D$ context representations per sample. This step is the most compute-intensive part of the forward pass.

Step 4: Target encoding (no gradient). Under torch.no_grad(), the target encoder $f_{\bar{\theta}}$ processes the $N_t$ target tokens (with positional embeddings). Output: $N_t \times D$ target representations per sample. These are optionally normalized (LayerNorm or L2-norm) to stabilize the training signal.

Step 5: Prediction. The predictor $g_\phi$ receives the context representations (with their positional embeddings) and learnable mask tokens (with the target positions' positional embeddings). It processes the concatenated sequence through its $L_p$ Transformer layers and outputs predicted representations at the target positions. Output: $N_t \times D$ predicted representations per sample.

Step 6: Loss computation. The L2 loss is computed between the predictor's output $\hat{\mathbf{z}}_j$ and the (stop-gradiented) target encoder's output $\mathbf{z}_j$ at each target position, then averaged over all target positions and all samples in the batch.

Step 7: Backward pass and parameter update. Gradients are computed with respect to $\theta$ (context encoder) and $\phi$ (predictor) only. AdamW updates both parameter sets. The learning rate follows a cosine decay schedule with linear warmup.

Step 8: EMA update. The target encoder parameters are updated: $\bar{\theta} \leftarrow \tau(t) \bar{\theta} + (1-\tau(t)) \theta$, where $\tau(t)$ follows a cosine schedule from 0.996 to 1.0.
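The eight steps reduce to the following shape trace, with identity stand-ins for the tokenizer, encoders, and predictor, and a random (rather than multi-block) mask split; only the tensor shapes and the latent-space loss are faithful to the text.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, D = 4, 64, 384
Nt = int(N * 0.6); Nc = N - Nt                      # 38 targets, 26 context

tokens = rng.standard_normal((B, N, D))             # Step 1: tokenizer output
perm = rng.permutation(N)
ctx_idx, tgt_idx = perm[:Nc], perm[Nc:]             # Step 2: mask split (random here)

h = tokens[:, ctx_idx]                              # Step 3: context encoder (identity)
z = tokens[:, tgt_idx]                              # Step 4: target encoder, no grad
z_hat = z + 0.01 * rng.standard_normal(z.shape)     # Step 5: predictor (stand-in)
loss = np.mean(np.sum((z_hat - z) ** 2, axis=-1))   # Step 6: L2 in latent space
# Steps 7-8 (AdamW update, EMA update) act on parameters and are omitted here
print(h.shape, z_hat.shape)  # (4, 26, 384) (4, 38, 384)
```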

Training Diagram

[Figure 3 diagram: point cloud ($B \times 1024 \times 3$) → tokenizer (FPS+KNN+mini-PointNet) → tokens ($B \times 64 \times 384$) → multi-block mask ($M=4$, $r=0.6$) → context tokens ($B \times N_c \times 384$, $N_c \approx 26$) through the trainable 12-layer context encoder ($\nabla\theta$); target tokens ($B \times N_t \times 384$, $N_t \approx 38$) through the frozen EMA target encoder (stop-gradient) → 6-layer predictor with mask tokens and positional embeddings ($\nabla\phi$) → L2 loss $\|\hat{\mathbf{z}} - \text{sg}(\mathbf{z})\|^2$; EMA update $\bar{\theta} \leftarrow \tau\bar{\theta} + (1-\tau)\theta$]
Figure 3. Detailed training flow for one iteration of 3D-JEPA. Solid green borders and arrows denote trainable components through which gradients flow. Dashed borders and arrows denote the frozen (EMA) target encoder with stop-gradient. Dimension annotations show tensor shapes at each stage for a batch of $B$ samples with $N=64$ patches, $D=384$.

Training Dynamics and Schedule

Pretraining is conducted on ShapeNet for 300 epochs. The learning rate follows a cosine decay schedule with linear warmup over the first 10 epochs. The EMA momentum $\tau$ increases from 0.996 to 1.0 via a cosine schedule, meaning the target encoder tracks the context encoder closely at the beginning of training (fast updates) and stabilizes near the end (almost frozen). Weight decay of 0.05 is applied to non-bias, non-normalization parameters as standard for AdamW-based Transformer training.
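The described learning-rate schedule can be sketched as below. The base rate and warmup length come from the hedged values in Section 5; `min_lr` is an additional assumption of this sketch.

```python
import math

def lr_schedule(epoch, total=300, warmup=10, base_lr=1.5e-4, min_lr=1e-6):
    """Linear warmup for `warmup` epochs, then cosine decay toward min_lr."""
    if epoch < warmup:
        return base_lr * (epoch + 1) / warmup   # linear ramp to base_lr
    t = (epoch - warmup) / (total - warmup)     # progress through decay phase
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * t))

print(lr_schedule(9))  # peak rate reached at the end of warmup
```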

Data augmentation for point clouds during pretraining typically includes random rotation (around the vertical axis), random scaling (0.67–1.5×), random translation, and random point jittering. These augmentations are applied to the raw point cloud before tokenization.

8. Inference

After pretraining, only the context encoder $f_\theta$ (or equivalently the target encoder $f_{\bar{\theta}}$, which has accumulated the EMA of $f_\theta$'s parameters) is retained for downstream tasks. The predictor $g_\phi$ is discarded—it was a training scaffold and has no role at inference time. Critically, no masking is applied at inference: the full set of $N$ tokens is fed to the encoder.

Downstream Protocols

Linear probing. The pretrained encoder is frozen. A linear classifier is trained on top of the encoder's output (typically the [CLS] token or mean-pooled patch representations) using the downstream task's labeled data and cross-entropy loss.

Fine-tuning. The pretrained encoder's weights initialize a new model. A task-specific head (linear layer for classification, segmentation head for part segmentation) is added, and the entire model is trained end-to-end on the downstream dataset with a lower learning rate for the pretrained parameters.

Few-shot evaluation. A small number of labeled examples per class are used to train either a linear classifier or a nearest-neighbor classifier (k-NN) on the frozen encoder's representations, evaluating the quality of the learned features under data scarcity.
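A minimal few-shot k-NN probe over frozen features might look like this; synthetic two-class vectors stand in for the encoder's outputs, and cosine similarity is one common (assumed) choice of metric.

```python
import numpy as np

rng = np.random.default_rng(0)
# Two well-separated synthetic "classes" of frozen-encoder features
train = np.vstack([rng.normal(0, 0.1, (20, 384)) + 1.0,
                   rng.normal(0, 0.1, (20, 384)) - 1.0])
labels = np.array([0] * 20 + [1] * 20)
query = np.vstack([np.ones((3, 384)), -np.ones((3, 384))])

def knn_classify(train_feats, train_labels, query_feats, k=3):
    """Majority vote among the k most cosine-similar frozen features."""
    a = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    b = query_feats / np.linalg.norm(query_feats, axis=1, keepdims=True)
    nn = np.argsort(-(b @ a.T), axis=1)[:, :k]   # indices of k nearest neighbors
    return np.array([np.bincount(train_labels[row]).argmax() for row in nn])

print(knn_classify(train, labels, query))  # [0 0 0 1 1 1]
```

No encoder parameters are updated in this protocol; only the representations are consumed.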

Inference Pipeline Diagram

[Figure 4 diagram: point cloud ($1024 \times 3$) → tokenizer (FPS+KNN+mini-PointNet) → all $N$ tokens, no masking → pretrained 12-layer encoder ($f_\theta$ or $f_{\bar{\theta}}$) → $N \times D$ representations → linear probe (ModelNet40) / fine-tuning (ScanObjectNN, ShapeNetPart) / few-shot k-NN; predictor $g_\phi$ discarded]
Figure 4. 3D-JEPA inference pipeline. The predictor and masking strategy are discarded. The pretrained encoder processes all $N$ tokens (no masking) to produce representations, which are then used for downstream tasks via linear probing, fine-tuning, or few-shot evaluation. The encoder may be either the final context encoder $f_\theta$ or the EMA target encoder $f_{\bar{\theta}}$.

9. Results & Benchmarks

3D-JEPA is evaluated on three standard 3D understanding benchmarks: ModelNet40 (shape classification), ScanObjectNN (real-world object classification), and ShapeNetPart (part segmentation). All models are pretrained on ShapeNet and then transferred.

Shape Classification — ModelNet40

ModelNet40 contains 12,311 CAD models across 40 categories (9,843 train / 2,468 test). Evaluation uses overall accuracy (OA).

| Method | Pretraining | OA (vote) |
|---|---|---|
| Transformer (scratch) | None | 91.4% |
| Point-BERT | dVAE + MLM | 93.2% |
| Point-MAE | Point reconstruction | 93.2% |
| Point-M2AE | Multi-scale reconstruction | 93.4% |
| 3D-JEPA | Representation prediction | 93.4% |

3D-JEPA matches or slightly exceeds Point-MAE and Point-BERT without ever reconstructing point coordinates, demonstrating that representation-space prediction captures sufficient geometric and semantic information for shape recognition.

Real-World Classification — ScanObjectNN

ScanObjectNN contains ~15,000 real-world scanned objects across 15 categories. The hardest variant (PB_T50_RS) adds background clutter plus random translation, rotation, and scaling perturbations, making it significantly more challenging than synthetic benchmarks.

| Method | OBJ_BG | OBJ_ONLY | PB_T50_RS (hardest) |
|---|---|---|---|
| Transformer (scratch) | 79.9% | 80.6% | 77.2% |
| Point-BERT | 87.4% | 88.1% | 83.1% |
| Point-MAE | 90.0% | 88.3% | 85.2% |
| Point-M2AE | 91.2% | 88.8% | 86.4% |
| 3D-JEPA | 91.5% | 89.7% | 86.8% |

On the hardest ScanObjectNN variant (PB_T50_RS), 3D-JEPA achieves 86.8%, surpassing Point-MAE by 1.6 percentage points. This gap is more pronounced than on the synthetic ModelNet40, suggesting that representation-space prediction is particularly advantageous when dealing with noisy, real-world 3D data where exact point reconstruction is less meaningful.

Part Segmentation — ShapeNetPart

ShapeNetPart contains 16,881 shapes across 16 categories, with 50 part labels total. Evaluation uses class-mean and instance-mean IoU.

| Method | Cat. mIoU | Inst. mIoU |
|---|---|---|
| Transformer (scratch) | 83.4% | 85.1% |
| Point-BERT | 84.1% | 85.6% |
| Point-MAE | 84.2% | 86.1% |
| 3D-JEPA | 84.4% | 86.3% |

Improvements on part segmentation are more modest, consistent with prior observations that self-supervised pretraining benefits are smaller on ShapeNetPart because its shapes are drawn from ShapeNet, the same distribution used for pretraining.

Ablations

Number of target blocks ($M$). The most important ablation validates the multi-block design:

| Target Blocks ($M$) | ScanObjectNN PB_T50_RS | ModelNet40 |
|---|---|---|
| $M = 1$ | 84.6% | 92.9% |
| $M = 2$ | 85.7% | 93.1% |
| $M = 4$ | 86.8% | 93.4% |
| $M = 8$ | 86.3% | 93.2% |

The jump from $M=1$ to $M=4$ is substantial (2.2 points on ScanObjectNN), confirming that multi-block sampling provides a harder, more informative prediction task than single-block masking. Beyond $M=4$, performance plateaus or slightly decreases, likely because too many small blocks reduce the spatial coherence of each target region.

Masking ratio ($r$). Performance is relatively stable across $r \in [0.5, 0.7]$, with $r=0.6$ being the sweet spot. Too little masking ($r < 0.4$) makes the prediction task too easy; too much masking ($r > 0.8$) leaves insufficient context for meaningful prediction.

Representation vs. reconstruction. Replacing the representation-space loss with a point-coordinate reconstruction loss (making the method equivalent to a multi-block Point-MAE variant) degrades ScanObjectNN performance by approximately 1–2 points, providing direct evidence for the advantage of latent prediction in the 3D domain.
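The two objectives compared in this ablation can be sketched side by side. The latent loss is shown as plain MSE over masked-token embeddings (the paper may use a smooth-L1 variant), and the Chamfer loss is the standard symmetric formulation used by Point-MAE-style reconstruction:

```python
import numpy as np

def latent_prediction_loss(pred, target):
    """Representation-space objective: regress target embeddings.

    `pred` are predictor outputs for masked tokens; `target` are the
    EMA encoder's embeddings (no gradient flows through that branch).
    """
    return np.mean((pred - target) ** 2)

def chamfer_loss(pred_pts, gt_pts):
    """Input-space alternative: symmetric Chamfer distance between
    reconstructed and ground-truth point patches, each (N, 3)."""
    d = np.linalg.norm(pred_pts[:, None, :] - gt_pts[None, :, :], axis=-1) ** 2
    return d.min(axis=1).mean() + d.min(axis=0).mean()
```

Swapping the first loss for the second is exactly the ablation described: the architecture stays fixed while the supervision target moves from latent space back to raw coordinates.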

10. Connection to JEPA Family

Lineage

3D-JEPA sits on a clear lineage within the JEPA family:

  1. JEPA (LeCun, 2022): The conceptual framework—predict in representation space, not input space—using an energy-based formulation with joint embedding architectures.
  2. I-JEPA (Assran et al., 2023): The first concrete instantiation for 2D images. Established the context-encoder / EMA-target-encoder / predictor / multi-block-masking recipe that 3D-JEPA inherits.
  3. 3D-JEPA (Hu et al., 2024): Adapts I-JEPA to 3D point clouds, addressing the fundamental challenges of irregular geometry, unordered points, and spatial tokenization.

Parallel to 3D-JEPA, other JEPA variants extend the framework to different modalities: Audio-JEPA for spectrograms, MC-JEPA for video with motion compensation, and Point-JEPA (a related but distinct approach to 3D point clouds). 3D-JEPA's specific contribution is demonstrating that the I-JEPA recipe transfers to 3D with appropriate tokenization and spatial masking adaptations.

Key Contribution: Multi-Block Masking in 3D Euclidean Space

While I-JEPA's multi-block masking operates on a regular 2D grid (rows and columns of image patches), 3D-JEPA must define spatial contiguity in continuous 3D Euclidean space. The paper's primary novelty is demonstrating that KNN-based block expansion from seed patches in $\mathbb{R}^3$ effectively creates spatially coherent target regions that require non-trivial 3D reasoning to predict. This is not a trivial adaptation: the irregular, non-grid structure of point clouds means that block shapes are amorphous (not rectangular), block sizes vary with local point density, and the context-target boundary is complex and three-dimensional. The ablation showing that $M=4$ blocks outperform $M=1$ validates that the multi-block principle discovered in 2D carries over to 3D, establishing a modality-general design principle for the JEPA family.

Relationship to Reconstruction-Based 3D Methods

3D-JEPA stands in explicit contrast to the dominant reconstruction-based paradigm in 3D self-supervised learning:

  • Point-BERT uses a discrete variational autoencoder (dVAE) to tokenize point clouds, then trains a BERT-like model to predict masked discrete tokens. This requires a separately trained tokenizer and optimizes a cross-entropy loss over a discrete codebook.
  • Point-MAE masks point patches and reconstructs the missing 3D coordinates via Chamfer distance. The decoder must predict exact point positions.
  • 3D-JEPA replaces both approaches with a single, unified representation-space prediction objective. No tokenizer pretraining, no point reconstruction, no decoder.

The consistently competitive results against these reconstruction-based methods provide strong evidence that, as conjectured by the JEPA framework, representation-space prediction is at least as effective as input-space reconstruction for learning transferable features—and potentially superior for robustness to noise and domain shift, as evidenced by the larger gains on the real-world ScanObjectNN benchmark.

Influence

3D-JEPA establishes the JEPA framework as a viable paradigm for 3D self-supervised learning, alongside reconstruction-based and contrastive approaches. Its success suggests that future 3D representation learning methods can benefit from operating entirely in latent space, potentially enabling more efficient pretraining (no decoder needed) and better downstream transfer (representations are inherently abstract). The multi-block masking strategy in 3D may inform future work on other irregular data modalities such as graphs, meshes, and molecular structures.

11. Summary

Key Takeaway. 3D-JEPA demonstrates that the Joint Embedding Predictive Architecture—predicting in representation space rather than reconstructing input coordinates—transfers effectively from 2D images to 3D point clouds. By adapting I-JEPA's multi-block masking strategy to 3D Euclidean space via KNN-based spatial block sampling, 3D-JEPA achieves results competitive with or superior to established reconstruction-based methods (Point-MAE, Point-BERT) on ModelNet40, ScanObjectNN, and ShapeNetPart, with the largest gains on the challenging real-world ScanObjectNN benchmark (+1.6 points over Point-MAE on PB_T50_RS). The ablation confirming that $M=4$ target blocks outperform $M=1$ by 2.2 points validates the multi-block principle as a modality-general design choice within the JEPA family. 3D-JEPA's simplicity—no codebook, no decoder, no contrastive negatives—makes it an attractive foundation for future 3D representation learning research.

12. References

  1. Hu, T., Cheng, Z., Xie, H., Li, X., & Zhu, J. (2024). 3D-JEPA: A Joint Embedding Predictive Architecture for 3D Self-Supervised Representation Learning. arXiv preprint arXiv:2409.15803.
  2. Assran, M., Duval, Q., Misra, I., Bojanowski, P., Vincent, P., Rabbat, M., LeCun, Y., & Ballas, N. (2023). Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture. CVPR 2023.
  3. LeCun, Y. (2022). A Path Towards Autonomous Machine Intelligence. OpenReview.
  4. Pang, Y., Wang, W., Tay, F. E. H., Liu, W., Tian, Y., & Yuan, L. (2022). Masked Autoencoders for Point Cloud Self-supervised Learning. ECCV 2022.
  5. Yu, X., Tang, L., Rao, Y., Huang, T., Zhou, J., & Lu, J. (2022). Point-BERT: Pre-training 3D Point Cloud Transformers with Masked Point Modeling. CVPR 2022.
  6. Zhang, R., Guo, Z., Zhang, W., Li, K., Miao, X., Cui, B., Qiao, Y., Gao, P., & Li, H. (2022). Point-M2AE: Multi-scale Masked Autoencoders for Hierarchical Point Cloud Pre-training. NeurIPS 2022.
  7. Qi, C. R., Su, H., Mo, K., & Guibas, L. J. (2017). PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation. CVPR 2017.
  8. Qi, C. R., Yi, L., Su, H., & Guibas, L. J. (2017). PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space. NeurIPS 2017.
  9. Uy, M. A., Pham, Q.-H., Hua, B.-S., Nguyen, D. T., & Yeung, S.-K. (2019). Revisiting Point Cloud Classification: A New Benchmark Dataset and Classification Model on Real-World Data. ICCV 2019. [ScanObjectNN]
  10. Wu, Z., Song, S., Khosla, A., Yu, F., Zhang, L., Tang, X., & Xiao, J. (2015). 3D ShapeNets: A Deep Representation for Volumetric Shapes. CVPR 2015. [ModelNet40]
  11. Yi, L., Kim, V. G., Ceylan, D., Shen, I., Yan, M., Su, H., Lu, C., Huang, Q., Sheffer, A., & Guibas, L. (2016). A Scalable Active Framework for Region Annotation in 3D Shape Collections. SIGGRAPH Asia 2016. [ShapeNetPart]
  12. Grill, J.-B., Strub, F., Altché, F., Tallec, C., Richemond, P. H., Buchatskaya, E., Doersch, C., Pires, B. A., Guo, Z. D., Azar, M. G., et al. (2020). Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning. NeurIPS 2020.