Authors: Saito, Kudeshia, Poovvancheri
Date: 2024-04
Category: 3D / Spatial
Derives from: I-JEPA

Point-JEPA: A Joint Embedding Predictive Architecture for Self-Supervised Learning on Point Cloud

Saito, Y., Kudeshia, A., & Poovvancheri, J. (2024). arXiv:2404.16432

1. Introduction

Three-dimensional point clouds are a fundamental representation for spatial perception in robotics, autonomous driving, and scene understanding. Unlike images, which lie on regular 2D grids, point clouds are inherently unordered, sparse, and irregularly distributed in continuous 3D space. These structural differences make it non-trivial to transfer advances in self-supervised learning (SSL) — particularly the recent success of Joint Embedding Predictive Architectures (JEPA) — from the image domain to the point cloud domain.

Prior to Point-JEPA, self-supervised methods for point clouds primarily followed two paradigms: contrastive learning (e.g., PointContrast, CrossPoint) and masked autoencoding (e.g., Point-MAE, Point-BERT). Contrastive methods learn by pulling together augmented views of the same object while pushing apart different objects, but they require careful augmentation design and can suffer from representation collapse without negative sampling or asymmetric architecture tricks. Masked autoencoders reconstruct raw point coordinates in input space, forcing the encoder to attend to low-level geometric detail rather than semantic abstractions — a problem well-documented in the 2D domain by Assran et al. (2023) when motivating I-JEPA.

Point-JEPA (Saito, Kudeshia, & Poovvancheri, 2024) proposes a direct adaptation of the I-JEPA framework to 3D point clouds, replacing pixel-patch tokenization with point-cloud patch extraction via farthest point sampling (FPS) and $k$-nearest neighbor ($k$-NN) grouping, and replacing 2D spatial block masking with 3D spatial region masking. The core thesis is that predicting abstract latent representations of masked 3D regions — rather than reconstructing raw $(x,y,z)$ coordinates — encourages the model to learn high-level semantic and geometric features suitable for downstream tasks such as object classification and scene understanding.

Key distinction from I-JEPA: Where I-JEPA operates on a regular 2D image grid and can mask rectangular blocks of contiguous patch tokens, Point-JEPA must handle irregular 3D geometry. Patches are not grid-aligned cells but variable-density neighborhoods in Euclidean space, and masking must respect spatial locality in three dimensions rather than two. This seemingly minor change requires rethinking how context and target regions are defined, how positional information is encoded, and how the predictor conditions on spatial location.

The contributions of Point-JEPA are threefold: (1) it demonstrates that the JEPA paradigm — prediction in latent space with an EMA target encoder and no pixel/point-level reconstruction — transfers effectively to 3D point clouds; (2) it introduces a 3D spatial region masking strategy that respects the irregular geometry of point clouds; and (3) it evaluates the approach on both synthetic benchmarks (ModelNet40, ShapeNet) and real-world LiDAR data (ScanObjectNN, nuScenes), showing competitive or superior performance compared to contrastive and masked-autoencoder baselines under linear evaluation and fine-tuning protocols.

2. Method

Point-JEPA's method can be understood through a spatial reasoning analogy. Imagine you are standing in a room and someone covers several regions of the room with opaque screens. You can still see parts of the room — furniture, walls, floor — through the gaps. Now, instead of asking you to draw what is behind each screen (pixel-level reconstruction), someone asks you to describe it at a conceptual level: "there's probably a chair there, oriented toward the window, about one meter tall." This is the essence of JEPA-style prediction: reasoning about hidden content in a semantic feature space rather than reconstructing exact sensory input.

Intuition — Why predict in latent space for 3D? Point clouds from LiDAR sensors are noisy, sparse, and variable-density. A point at 50 meters might have very few neighbors, while a nearby surface might be densely sampled. Forcing a model to reconstruct exact 3D coordinates of masked points would tie its representations to sensor-specific density patterns and noise characteristics rather than the underlying geometry and semantics. By predicting in latent space, Point-JEPA abstracts away these sensor artifacts and focuses on what the masked region means rather than its exact point distribution.

The method proceeds in four stages for each training sample:

  1. Patch extraction: The raw point cloud is tokenized into a set of local patches using farthest point sampling to select patch centers, then $k$-nearest neighbor grouping to collect points around each center. Each patch is embedded into a $D$-dimensional token via a lightweight point embedding module (e.g., a small PointNet or linear projection on local coordinates).
  2. 3D spatial masking: A subset of patch tokens is designated as the target set $\mathcal{T}$ using a spatial region masking strategy — selecting seed centers in 3D and expanding to nearby patches. The complementary visible patches form the context set $\mathcal{C}$. Unlike image JEPA, this masking operates in continuous 3D space rather than on a discrete grid.
  3. Encoding: The context patches $\mathcal{C}$ are processed by the online encoder (a Transformer), producing context representations. Simultaneously, all patches (including targets) are processed by the target encoder — a momentum-updated (EMA) copy of the online encoder — producing target representations. The target encoder receives no gradients; it is updated only via exponential moving average of the online encoder's weights.
  4. Prediction and loss: A lightweight predictor network takes the context representations and, conditioned on the 3D positions of the masked target patches, predicts what the target encoder's representations of those masked patches should be. The training loss is the mean squared error between the predictor's outputs and the (stop-gradient) target representations.

Intuition — The role of the predictor: The predictor is intentionally narrow (lower dimensionality than the encoder). This bottleneck prevents a degenerate shortcut where the predictor could simply copy context features without understanding spatial relationships. It must compress and reason about what lies at a particular 3D location, given the visible surroundings. This is analogous to requiring a concise summary rather than allowing copy-paste — the compression forces genuine understanding.

Crucially, Point-JEPA uses no data augmentation beyond the masking itself, no negative pairs, and no pixel/point-level reconstruction decoder. The entire learning signal comes from predicting latent representations of spatially masked regions. Collapse is mitigated through the combination of EMA target updates, the asymmetric encoder–predictor architecture, and the narrow predictor bottleneck — though it should be noted that the precise mechanisms by which these components interact to prevent collapse remain an active area of theoretical investigation in the broader SSL literature.

3. Model Overview

At-a-Glance

| Component | Detail |
|---|---|
| Input | 3D point clouds (LiDAR or object-level); $N$ points with $(x, y, z)$ coordinates (optionally with normals or intensity) |
| Tokenization | FPS center selection + $k$-NN grouping → $M$ patches, each embedded to $\mathbb{R}^D$ |
| Masking | 3D spatial region masking: seed-and-expand in Euclidean space; ~60–75% of patches masked as targets |
| Online Encoder | Standard Transformer (e.g., 12 layers, 384-dim, 6 heads) operating on context patches only |
| Target Encoder | EMA copy of online encoder; processes all patches; no gradient |
| Predictor | Narrow Transformer (e.g., 6 layers, 192-dim) mapping context representations → target predictions conditioned on 3D positional tokens |
| Loss | $\ell_2$ (MSE) in latent space between predictor outputs and target encoder representations |
| Key Result | Competitive with Point-MAE and contrastive baselines on ModelNet40 and ScanObjectNN under linear probing |
| Parameters | ~22M (encoder) + ~5M (predictor); varies with configuration |

Training Architecture Diagram

Figure 1: Point-JEPA training architecture. The raw point cloud is tokenized into patches, then split into context and target sets via 3D spatial masking. The context encoder (trainable) processes visible patches; the target encoder (EMA, frozen) processes all patches. The predictor (trainable, narrow) predicts target representations conditioned on 3D positional tokens. Loss is computed in latent space. Gradients flow only to the context encoder and predictor.

4. Main Components of Point-JEPA

4.1 Point Cloud Patch Tokenizer

WHAT: The patch tokenizer converts a raw point cloud of $N$ points into a fixed set of $M$ patch tokens, each represented as a $D$-dimensional vector. This is the 3D analog of the image patch embedding in Vision Transformers, but adapted for irregular point distributions.

HOW: The tokenization process uses a two-stage approach:

  1. Farthest Point Sampling (FPS): From the input point cloud $\mathcal{P} = \{p_i\}_{i=1}^{N} \subset \mathbb{R}^3$, FPS iteratively selects $M$ center points $\{c_j\}_{j=1}^{M}$ that are maximally spread across the 3D volume. At each step, the point farthest from all previously selected centers is chosen, yielding approximately uniform spatial coverage regardless of point density.
  2. $k$-Nearest Neighbor Grouping: For each center $c_j$, the $k$ nearest points are gathered to form a local patch $\mathcal{P}_j = \{p \in \mathcal{P} : p \in \text{kNN}(c_j, k)\}$. Coordinates within each patch are typically normalized relative to the center: $p' = p - c_j$.
  3. Patch Embedding: Each local patch is embedded into $\mathbb{R}^D$ via a lightweight network — typically a mini-PointNet consisting of shared MLPs followed by max-pooling, or a linear projection on concatenated local coordinates. The center coordinates $c_j$ are retained for positional encoding.

Typical hyperparameters reported in Point-JEPA: $M = 64$ or $128$ patch centers, $k = 32$ neighbors per patch, embedding dimension $D = 384$.
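The two-stage tokenization can be sketched in a few lines of NumPy. This is an illustrative implementation, not the authors' code; the deterministic first FPS center and all function names are assumptions:

```python
import numpy as np

def farthest_point_sampling(points: np.ndarray, m: int) -> np.ndarray:
    """Iteratively pick m point indices that are maximally spread out."""
    n = points.shape[0]
    centers = np.zeros(m, dtype=np.int64)
    dist = np.full(n, np.inf)  # distance of each point to its nearest chosen center
    centers[0] = 0             # deterministic start; a random start is also common
    for i in range(1, m):
        gap = np.linalg.norm(points - points[centers[i - 1]], axis=1)
        dist = np.minimum(dist, gap)
        centers[i] = int(np.argmax(dist))  # farthest point from all chosen centers
    return centers

def knn_group(points: np.ndarray, centers: np.ndarray, k: int) -> np.ndarray:
    """For each center, gather its k nearest points, re-centered locally (p' = p - c_j)."""
    patches = []
    for c in centers:
        d = np.linalg.norm(points - points[c], axis=1)
        idx = np.argsort(d)[:k]
        patches.append(points[idx] - points[c])
    return np.stack(patches)  # (M, k, 3)

rng = np.random.default_rng(0)
cloud = rng.normal(size=(1024, 3)).astype(np.float32)
centers = farthest_point_sampling(cloud, m=64)
patches = knn_group(cloud, centers, k=32)
print(patches.shape)  # (64, 32, 3)
```

Each $(M, k, 3)$ patch tensor would then be fed to the mini-PointNet embedding to produce the $M \times D$ token matrix.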

WHY: FPS ensures coverage of the full 3D extent, while $k$-NN grouping creates patches with consistent local geometry. The paper notes that this tokenization strategy is critical for adapting the JEPA framework to point clouds — unlike images where patches are axis-aligned rectangles, point cloud patches are volumetric neighborhoods that naturally adapt to surface geometry and density variations. Alternative tokenization strategies (e.g., random sampling, voxelization) were considered; FPS + $k$-NN is standard in point cloud Transformers (Zhao et al., 2021; Pang et al., 2022) and is adopted without modification for Point-JEPA.

4.2 Context Encoder (Online Encoder)

WHAT: The context encoder $f_\theta$ is a standard Transformer that processes only the visible (unmasked) context patches to produce rich representations. This is the primary trainable component of Point-JEPA and the module used for downstream tasks after pretraining.

HOW: The encoder is a standard Vision Transformer architecture adapted for point cloud tokens:

  • Input: context patch tokens $\{z_j^{(0)}\}_{j \in \mathcal{C}} \in \mathbb{R}^{|\mathcal{C}| \times D}$, where $z_j^{(0)}$ is the patch embedding plus 3D positional encoding for patch $j$
  • Architecture: $L$ Transformer layers with multi-head self-attention (MHSA) and feed-forward networks (FFN)
  • Positional encoding: 3D sinusoidal or learned embeddings derived from FPS center coordinates $(x_j, y_j, z_j)$
  • Output: context representations $\{h_j\}_{j \in \mathcal{C}} \in \mathbb{R}^{|\mathcal{C}| \times D}$

Following Point-JEPA's description, the encoder configuration follows a ViT-Small-like setup: $L = 12$ layers, $D = 384$ dimensions, 6 attention heads, FFN expansion ratio of 4 (hidden dimension 1536), GELU activations, and LayerNorm.

WHY: Processing only context patches (not masked ones) has two benefits: (1) it reduces computational cost since the Transformer's self-attention is quadratic in sequence length, and (2) it forces the encoder to build representations that are useful even when large portions of the input are missing — encouraging robustness and semantic completeness. Point-JEPA follows I-JEPA in this design choice rather than processing mask tokens as placeholders (as in MAE-style approaches).

4.3 Target Encoder (EMA)

WHAT: The target encoder $f_{\bar{\theta}}$ is a copy of the online encoder whose parameters $\bar{\theta}$ are updated via exponential moving average (EMA) of the online encoder's parameters $\theta$. It processes all patch tokens (not just context) to produce the prediction targets.

HOW: The EMA update rule is:

$$\bar{\theta} \leftarrow \tau \bar{\theta} + (1 - \tau) \theta$$

where $\tau \in [0, 1)$ is the momentum coefficient, typically following a cosine schedule from $\tau_0 = 0.996$ to $\tau_1 = 1.0$ over the course of training. Early in training, $\tau$ is lower, allowing faster integration of the online encoder's rapidly changing parameters; as training stabilizes, $\tau$ approaches 1.0, making the target encoder a very slow-moving average.
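The update rule and its cosine momentum schedule can be sketched as follows (a minimal illustration; the exact schedule shape is assumed to match I-JEPA's cosine ramp):

```python
import math

def ema_momentum(step: int, total_steps: int,
                 tau0: float = 0.996, tau1: float = 1.0) -> float:
    """Cosine schedule for the EMA momentum: tau0 at step 0, tau1 at the end."""
    progress = step / max(1, total_steps)
    return tau1 - (tau1 - tau0) * 0.5 * (1.0 + math.cos(math.pi * progress))

def ema_update(target_params, online_params, tau: float):
    """theta_bar <- tau * theta_bar + (1 - tau) * theta, element-wise."""
    return [tau * tb + (1.0 - tau) * t
            for tb, t in zip(target_params, online_params)]

tau = ema_momentum(step=0, total_steps=100)   # 0.996 at the start of training
updated = ema_update([1.0], [0.0], tau)
```

In a real framework this loop runs over parameter tensors after every optimizer step, outside the autograd graph.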

The target encoder processes all $M$ patch tokens (both context and target patches), producing representations for every patch. The representations corresponding to the target patch indices are then extracted and used as prediction targets. Crucially, all operations involving the target encoder are within a stop_gradient (or torch.no_grad()) context — no gradients flow through the target encoder.

WHY: The EMA target encoder provides slowly evolving, stable prediction targets. Without this mechanism (i.e., if targets were produced by the same network receiving gradients), the system could collapse to trivial solutions where both encoder and predictor learn constant outputs. The EMA approach, borrowed from BYOL (Grill et al., 2020) and validated in I-JEPA (Assran et al., 2023), creates a bootstrap effect where the online encoder chases a smoother, temporally averaged version of itself. It should be noted that EMA alone is not a formal guarantee against collapse — the combination of EMA, the predictor bottleneck, and the masking strategy together constitute the collapse-avoidance mechanism, and the precise theoretical understanding of their interaction remains an open question.

4.4 Predictor

WHAT: The predictor $g_\phi$ is a narrow Transformer network that takes context representations from the online encoder and, conditioned on the 3D positions of masked target patches, predicts what the target encoder's representations of those patches should be.

HOW: The predictor operates as follows:

  1. Context representations $\{h_j\}_{j \in \mathcal{C}}$ from the online encoder are projected into the predictor's (narrower) dimension $D_p < D$.
  2. For each target patch index $t \in \mathcal{T}$, a learnable or sinusoidal 3D positional token $q_t \in \mathbb{R}^{D_p}$ is created based on the target patch's center coordinates. These are sometimes called mask tokens or query tokens.
  3. The predictor processes the concatenation of projected context representations and positional query tokens through $L_p$ Transformer layers with cross-attention (queries attend to context) or joint self-attention.
  4. The outputs corresponding to target query positions are projected back to dimension $D$ and constitute the predictions $\{\hat{s}_t\}_{t \in \mathcal{T}}$.

Key hyperparameters: $L_p = 6$ layers, $D_p = 192$ (half the encoder dimension), 6 attention heads. The narrowness of the predictor ($D_p \ll D$) is a deliberate bottleneck.

WHY: The narrow predictor serves multiple roles: (1) it creates an information bottleneck that prevents the trivial solution of copying context features to nearby positions; (2) it forces the context encoder to produce representations that are sufficiently rich to support prediction through a compressed channel; and (3) it is discarded after pretraining, so its parameters do not add to the downstream model's computational cost. Ablation studies in I-JEPA (and echoed in Point-JEPA) show that removing the bottleneck or using too-wide predictors leads to degraded representations, supporting the hypothesis that the compression forces the encoder to develop more informative features.
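The four predictor steps above can be illustrated with a single-head, single-layer cross-attention sketch (the actual predictor is a 6-layer Transformer; all weight matrices here are random placeholders and the function names are hypothetical):

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def predictor_forward(ctx, pos_queries, w_in, wq, wk, wv, w_out):
    """Sketch of one predictor step: queries attend to compressed context.

    ctx:         (|C|, D)  context representations from the online encoder
    pos_queries: (|T|, Dp) 3D positional query tokens for masked patches
    w_in (D, Dp), wq/wk/wv (Dp, Dp), w_out (Dp, D): illustrative projections
    """
    ctx_p = ctx @ w_in                              # step 1: project into narrow Dp
    q = pos_queries @ wq                            # step 2-3: queries from positions
    k, v = ctx_p @ wk, ctx_p @ wv                   # keys/values from context
    att = softmax(q @ k.T / np.sqrt(q.shape[-1]))   # (|T|, |C|) attention weights
    return (att @ v) @ w_out                        # step 4: project back to D

rng = np.random.default_rng(1)
D, Dp = 384, 192
ctx = rng.normal(size=(22, D))
pq = rng.normal(size=(42, Dp))
out = predictor_forward(ctx, pq,
                        rng.normal(size=(D, Dp)), rng.normal(size=(Dp, Dp)),
                        rng.normal(size=(Dp, Dp)), rng.normal(size=(Dp, Dp)),
                        rng.normal(size=(Dp, D)))
print(out.shape)  # (42, 384)
```

The key structural point is the shape flow: |C| context tokens in, |T| predicted target representations out, with all intermediate computation in the narrower $D_p$ space.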

4.5 Masking Strategy (3D Spatial Region Masking)

WHAT: Point-JEPA extends I-JEPA's multi-block masking from 2D image grids to 3D point cloud space. The masking strategy selects spatially coherent regions of the point cloud as prediction targets, ensuring that the model must reason about 3D spatial structure rather than interpolating from immediate neighbors.

HOW: The 3D spatial region masking proceeds as:

  1. Seed selection: $K$ seed points are randomly selected from the $M$ FPS centers (typically $K = 4$ target regions per sample).
  2. Region expansion: For each seed, nearby patches within a radius $r$ in Euclidean 3D space are added to the target set. Alternatively, a fixed number of nearest patches $n_{\text{expand}}$ are selected per seed.
  3. Target set assembly: The union of all expanded regions forms the target set $\mathcal{T}$. Overlap between regions is resolved by deduplication. The target ratio is controlled to be approximately 60–75% of all patches.
  4. Context set: The remaining patches form the context $\mathcal{C} = \{1, \ldots, M\} \setminus \mathcal{T}$, typically 25–40% of patches.

The high masking ratio is critical: by making the model predict a majority of the point cloud from a minority of visible patches, the task becomes non-trivial and requires genuine understanding of 3D structure.

Properties of the 3D spatial masking strategy:

  • Operates in continuous Euclidean space (not a grid)
  • $K = 4$ seed regions per sample (typically)
  • ~60–75% of patches masked as targets
  • Regions are spatially contiguous → forces semantic reasoning
  • Adapts to density: sparse regions yield larger spatial coverage
  • No augmentation beyond masking is needed
Figure 2: 3D spatial region masking in Point-JEPA. (1) All patch centers from FPS are shown. (2) K=4 seed centers are randomly selected and expanded to nearby patches within a radius, forming target regions. (3) The resulting context/target split: solid circles are visible context patches fed to the encoder; dashed circles are masked target patches whose representations must be predicted.

WHY: Spatial region masking (as opposed to random token masking) ensures that the masked regions are contiguous and semantically coherent — covering, for example, an entire wheel of a car or a section of a wall rather than scattered individual points. This forces the predictor to reason about 3D structure at a high level. The paper reports that spatially coherent masking significantly outperforms random masking (typically by 2–4% on linear probing accuracy), consistent with findings in I-JEPA for 2D images. The high masking ratio (60–75%) is also essential: lower ratios make the task too easy (nearby visible patches provide trivial interpolation cues), while higher ratios leave insufficient context for meaningful prediction.

4.6 Loss Function

WHAT: Point-JEPA minimizes the mean squared error between the predictor's outputs and the target encoder's representations for masked patches.

HOW: Given a batch of $B$ point clouds, for each sample $i$ with target patch indices $\mathcal{T}_i$ and context patch indices $\mathcal{C}_i$, the loss is:

$$\mathcal{L} = \frac{1}{B} \sum_{i=1}^{B} \frac{1}{|\mathcal{T}_i|} \sum_{t \in \mathcal{T}_i} \left\| \hat{s}_t^{(i)} - \text{sg}\!\left(s_t^{(i)}\right) \right\|_2^2$$

where:

  • $B$ is the batch size
  • $\mathcal{T}_i$ is the set of target patch indices for sample $i$
  • $\hat{s}_t^{(i)} = g_\phi\!\left(f_\theta(\{z_j\}_{j \in \mathcal{C}_i}),\; \text{pos}(t)\right) \in \mathbb{R}^D$ is the predictor's output for target patch $t$ in sample $i$, conditioned on the context encoder's representations and the 3D position of patch $t$
  • $s_t^{(i)} = f_{\bar{\theta}}(\{z_j\}_{j=1}^{M})[t] \in \mathbb{R}^D$ is the target encoder's representation for patch $t$ in sample $i$, extracted from the full set of patch representations
  • $\text{sg}(\cdot)$ denotes the stop-gradient operator, ensuring no gradient flows through the target encoder
  • $\|\cdot\|_2^2$ is the squared $\ell_2$ norm
  • $f_\theta$ is the online (context) encoder with trainable parameters $\theta$
  • $f_{\bar{\theta}}$ is the target encoder with EMA parameters $\bar{\theta}$
  • $g_\phi$ is the predictor with trainable parameters $\phi$
  • $\text{pos}(t) \in \mathbb{R}^{D_p}$ is the 3D positional encoding of target patch $t$'s center coordinates

In some formulations, the target representations are additionally layer-normalized before computing the loss, which stabilizes training by preventing the target encoder from driving representations toward large-magnitude vectors. When target normalization is applied:

$$\mathcal{L} = \frac{1}{B} \sum_{i=1}^{B} \frac{1}{|\mathcal{T}_i|} \sum_{t \in \mathcal{T}_i} \left\| \hat{s}_t^{(i)} - \text{sg}\!\left(\text{LN}(s_t^{(i)})\right) \right\|_2^2$$

where $\text{LN}(\cdot)$ denotes layer normalization applied per-token.

WHY: The $\ell_2$ loss in latent space is the simplest choice that aligns predicted and target representations without introducing the complexities of contrastive losses (which require negative pairs) or reconstruction losses (which operate in input space). The per-target-patch averaging ensures that the loss is not dominated by samples with more target patches. Layer normalization of targets, when used, helps stabilize training dynamics — without it, the target encoder's representations can drift to large magnitudes, which the predictor may exploit rather than learning useful features. However, it should be noted that the precise effect of target normalization on collapse prevention is empirical rather than theoretically guaranteed.
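The loss above can be sketched in NumPy (an illustration, not the authors' code; the stop-gradient is implicit here because the targets are plain constant arrays rather than autograd tensors):

```python
import numpy as np

def pointjepa_loss(preds: np.ndarray, targets: np.ndarray,
                   normalize_targets: bool = True) -> float:
    """MSE in latent space between predictions and (stop-gradient) targets.

    preds, targets: (B, T, D) predicted / target-encoder representations.
    In a real framework the targets come from a no-grad forward pass.
    """
    if normalize_targets:
        # Per-token layer norm, as in the optional LN(s_t) variant.
        mu = targets.mean(axis=-1, keepdims=True)
        sigma = targets.std(axis=-1, keepdims=True)
        targets = (targets - mu) / (sigma + 1e-6)
    per_token = ((preds - targets) ** 2).sum(axis=-1)  # squared l2, shape (B, T)
    return float(per_token.mean(axis=-1).mean())       # avg over targets, then batch

x = np.ones((2, 3, 4))
print(pointjepa_loss(x, x, normalize_targets=False))  # 0.0
```

Averaging over $\mathcal{T}_i$ before averaging over the batch matches the formula: samples with more target patches do not dominate the loss.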

4.7 3D Positional Encoding

WHAT: A critical component specific to Point-JEPA is how 3D spatial positions are encoded and injected into the Transformer architecture. Unlike images where positions lie on a regular 2D grid, point cloud patch centers occupy arbitrary locations in continuous 3D space.

HOW: Point-JEPA uses sinusoidal positional encodings extended to three dimensions. For a patch center at coordinates $(x, y, z)$, the positional encoding is the concatenation of per-axis sinusoidal features:

$$\text{PE}(x, y, z) = \left[\sin(x / \omega_1), \cos(x / \omega_1), \ldots, \sin(z / \omega_{D/6}), \cos(z / \omega_{D/6})\right]$$

where $\omega_k$ are frequency bases (typically geometric progression), and the total encoding has dimension $D$ (equally split across the three axes, $D/3$ per axis). These encodings are added to the patch embeddings before being input to both the context and target encoders, and are used as positional queries for the predictor's target tokens.
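A minimal sketch of the per-axis encoding (illustrative only; the frequency base of 100 is an assumption — the exact bases are not specified in this summary):

```python
import numpy as np

def sinusoidal_pe_3d(centers: np.ndarray, d: int) -> np.ndarray:
    """3D sinusoidal positional encoding: D/3 dims per axis (d divisible by 6).

    centers: (M, 3) patch-center coordinates; returns (M, d).
    """
    assert d % 6 == 0
    d_axis = d // 3  # dims per axis: d_axis/2 sin + d_axis/2 cos
    # Geometric progression of frequency bases (base 100 chosen for metric coords).
    freqs = 1.0 / (100.0 ** (np.arange(d_axis // 2) / (d_axis // 2)))
    parts = []
    for axis in range(3):
        angles = centers[:, axis:axis + 1] * freqs  # (M, d_axis/2)
        parts.append(np.sin(angles))
        parts.append(np.cos(angles))
    return np.concatenate(parts, axis=1)  # (M, d)

pe = sinusoidal_pe_3d(np.zeros((5, 3)), d=384)
print(pe.shape)  # (5, 384)
```

The resulting encoding is added to patch embeddings for both encoders and reused as the positional query tokens for the predictor.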

WHY: The positional encoding is particularly important for the predictor, which must know where in 3D space the target patches are located to make meaningful predictions. Without positional conditioning, the predictor would have no way to distinguish which masked patch it should predict — the loss would become ill-posed. The sinusoidal encoding provides smooth, continuous positional information that generalizes across scales and spatial extents, which is important for LiDAR data where the spatial extent varies significantly across scenes.

5. Implementation Details

| Hyperparameter | Value | Notes |
|---|---|---|
| Encoder layers | 12 | Standard Transformer blocks |
| Encoder dimension | 384 | ViT-Small/Base scale |
| Encoder attention heads | 6 | Head dimension = 64 |
| FFN hidden dim | 1536 | 4× expansion ratio |
| Predictor layers | 6 | Narrower Transformer |
| Predictor dimension | 192 | Half of encoder dim |
| Predictor heads | 6 | Head dimension = 32 |
| Number of patches ($M$) | 64 or 128 | Via FPS center selection |
| Points per patch ($k$) | 32 | Via $k$-NN grouping |
| Input points ($N$) | 1024 or 2048 | Dataset-dependent |
| Masking ratio | 60–75% | Spatially coherent regions |
| Number of target regions ($K$) | 4 | Seeds for region expansion |
| Optimizer | AdamW | $\beta_1 = 0.9$, $\beta_2 = 0.999$ |
| Learning rate | 1.5 × 10⁻⁴ | Peak LR after warmup |
| LR schedule | Cosine decay | With linear warmup |
| Warmup epochs | 40 | Linear LR warmup |
| Total epochs | 300 | Pretraining duration |
| Batch size | 128 or 256 | Dataset-dependent |
| Weight decay | 0.05 | Applied to non-bias/norm params |
| EMA momentum ($\tau$) | 0.996 → 1.0 | Cosine schedule |
| GPUs | Not explicitly stated | Likely 1–4 GPUs based on scale |
| Positional encoding | 3D sinusoidal | Applied per FPS center |

Note: Point-JEPA does not have a publicly available code repository as of April 2024. The hyperparameters above are drawn from the paper's method description and experimental setup. Where specific values are not stated in the paper, values are inferred from I-JEPA defaults and standard point cloud Transformer practice, and are marked accordingly in the "Notes" column. Readers should consult the original paper for definitive values.
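For reference, the table can be collected into a single configuration object (a convenience sketch; the class name and field grouping are ours, and values follow the caveats above):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class PointJEPAConfig:
    # Encoder (some values inferred from I-JEPA defaults; see note above)
    encoder_layers: int = 12
    encoder_dim: int = 384
    encoder_heads: int = 6
    ffn_hidden_dim: int = 1536
    # Predictor
    predictor_layers: int = 6
    predictor_dim: int = 192
    predictor_heads: int = 6
    # Tokenization / masking
    num_patches: int = 64        # M
    points_per_patch: int = 32   # k
    num_points: int = 1024       # N
    mask_ratio: float = 0.65     # within the reported 60-75% range
    num_target_regions: int = 4  # K
    # Optimization
    lr: float = 1.5e-4
    weight_decay: float = 0.05
    warmup_epochs: int = 40
    total_epochs: int = 300
    ema_tau: tuple = (0.996, 1.0)

cfg = PointJEPAConfig()
```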

6. Algorithm

Algorithm 1: Point-JEPA Pretraining
Input: Point cloud dataset $\mathcal{D} = \{\mathcal{P}^{(i)}\}_{i=1}^{N_{\text{data}}}$; encoder $f_\theta$; target encoder $f_{\bar{\theta}}$ (init: $\bar{\theta} \leftarrow \theta$); predictor $g_\phi$; patch tokenizer $\text{Tok}(\cdot)$; masking function $\text{Mask}(\cdot)$; EMA schedule $\tau(t)$; total iterations $T$; learning rate schedule $\eta(t)$
Output: Pretrained encoder parameters $\theta^*$
 
1 for $t = 1$ to $T$ do
2 Sample mini-batch $\{\mathcal{P}^{(i)}\}_{i=1}^{B}$ from $\mathcal{D}$
3 for each sample $\mathcal{P}^{(i)}$ in batch do
4 $\{z_j, c_j\}_{j=1}^{M} \leftarrow \text{Tok}(\mathcal{P}^{(i)})$   // Tokenize: FPS centers $c_j$, embeddings $z_j$
5 $\mathcal{C}_i, \mathcal{T}_i \leftarrow \text{Mask}(\{c_j\}_{j=1}^{M})$   // 3D spatial region masking
6 $\{h_j\}_{j \in \mathcal{C}_i} \leftarrow f_\theta(\{z_j + \text{PE}(c_j)\}_{j \in \mathcal{C}_i})$   // Encode context patches
7 with no_grad():
8 $\{s_j\}_{j=1}^{M} \leftarrow f_{\bar{\theta}}(\{z_j + \text{PE}(c_j)\}_{j=1}^{M})$   // Target encoder: all patches
9 $\{\hat{s}_t\}_{t \in \mathcal{T}_i} \leftarrow g_\phi(\{h_j\}_{j \in \mathcal{C}_i},\; \{\text{PE}(c_t)\}_{t \in \mathcal{T}_i})$   // Predict target representations
10 end for
11 $\mathcal{L} \leftarrow \frac{1}{B} \sum_{i=1}^{B} \frac{1}{|\mathcal{T}_i|} \sum_{t \in \mathcal{T}_i} \left\| \hat{s}_t - \text{sg}(s_t) \right\|_2^2$   // Compute loss
12 $\theta, \phi \leftarrow \text{AdamW}(\nabla_{\theta, \phi} \mathcal{L},\; \eta(t))$   // Update encoder + predictor
13 $\bar{\theta} \leftarrow \tau(t) \cdot \bar{\theta} + (1 - \tau(t)) \cdot \theta$   // EMA update target encoder
14 end for
15 return $\theta^*$   // Pretrained encoder for downstream use
Algorithm 2: 3D Spatial Region Masking
Input: Patch centers $\{c_j\}_{j=1}^{M} \subset \mathbb{R}^3$; number of target regions $K$; target ratio $\rho$ (e.g., 0.65); expansion parameter $n_{\text{expand}}$
Output: Context indices $\mathcal{C}$, Target indices $\mathcal{T}$
 
1 $\mathcal{T} \leftarrow \emptyset$
2 $n_{\text{target}} \leftarrow \lfloor \rho \cdot M \rfloor$   // Desired number of target patches
3 $\text{seeds} \leftarrow \text{RandomSample}(\{1, \ldots, M\},\; K)$   // Select K seed indices
4 $n_{\text{per\_seed}} \leftarrow \lceil n_{\text{target}} / K \rceil$   // Patches per seed region
5 for each seed $s \in \text{seeds}$ do
6 $\text{dists} \leftarrow \{\|c_s - c_j\|_2 : j = 1, \ldots, M\}$   // Euclidean distances to seed
7 $\text{neighbors} \leftarrow \text{argsort}(\text{dists})[:n_{\text{per\_seed}}]$   // Nearest patches
8 $\mathcal{T} \leftarrow \mathcal{T} \cup \text{neighbors}$   // Add to target set
9 end for
10 if $|\mathcal{T}| > n_{\text{target}}$ then
11 $\mathcal{T} \leftarrow \text{RandomSubset}(\mathcal{T},\; n_{\text{target}})$   // Trim to exact target count
12 end if
13 $\mathcal{C} \leftarrow \{1, \ldots, M\} \setminus \mathcal{T}$   // Context = complement
14 return $\mathcal{C}, \mathcal{T}$
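Algorithm 2 translates almost line-for-line into NumPy (a sketch under the stated pseudocode, not the authors' implementation):

```python
import numpy as np

def spatial_region_mask(centers: np.ndarray, k_regions: int = 4,
                        rho: float = 0.65, rng=None):
    """Seed-and-expand 3D masking: split patch indices into (context, target)."""
    if rng is None:
        rng = np.random.default_rng()
    m = centers.shape[0]
    n_target = int(rho * m)                 # desired number of target patches
    n_per_seed = -(-n_target // k_regions)  # ceil(n_target / K)
    seeds = rng.choice(m, size=k_regions, replace=False)
    target = set()
    for s in seeds:
        dists = np.linalg.norm(centers - centers[s], axis=1)  # distances to seed
        target.update(np.argsort(dists)[:n_per_seed].tolist())  # nearest patches
    target = np.array(sorted(target))
    if len(target) > n_target:              # overlap-free union may overshoot
        target = rng.choice(target, size=n_target, replace=False)
    context = np.setdiff1d(np.arange(m), target)  # context = complement
    return context, np.sort(target)

centers = np.random.default_rng(0).normal(size=(64, 3))
ctx, tgt = spatial_region_mask(centers, rng=np.random.default_rng(1))
print(len(ctx), len(tgt))
```

Note that heavy overlap between seed regions can leave slightly fewer than $n_{\text{target}}$ patches masked; the pseudocode only trims overshoot, it does not top up undershoot.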
Algorithm 3: Point-JEPA Inference (Feature Extraction)
Input: Point cloud $\mathcal{P}$; pretrained encoder $f_{\theta^*}$; patch tokenizer $\text{Tok}(\cdot)$
Output: Feature representation $h \in \mathbb{R}^D$ for downstream use
 
1 $\{z_j, c_j\}_{j=1}^{M} \leftarrow \text{Tok}(\mathcal{P})$   // Tokenize full point cloud (no masking)
2 $\{h_j\}_{j=1}^{M} \leftarrow f_{\theta^*}(\{z_j + \text{PE}(c_j)\}_{j=1}^{M})$   // Encode all patches
3 $h \leftarrow \text{AvgPool}(\{h_j\}_{j=1}^{M})$   // Global average pooling → single vector
4 return $h$   // Feed to linear probe or fine-tuning head

7. Training

Step-by-Step: One Training Iteration

A single training iteration of Point-JEPA proceeds through the following steps. We trace the data flow with concrete tensor dimensions assuming batch size $B = 128$, $M = 64$ patches per point cloud, $k = 32$ points per patch, encoder dimension $D = 384$, predictor dimension $D_p = 192$, and ~65% masking ratio (so $|\mathcal{T}| \approx 42$, $|\mathcal{C}| \approx 22$).

  1. Sample mini-batch: Draw $B = 128$ point clouds from the dataset, each containing $N = 1024$ points. Tensor: $B \times N \times 3 = 128 \times 1024 \times 3$.
  2. Patch tokenization: For each point cloud, apply FPS to select $M = 64$ centers, then $k$-NN to group 32 nearest neighbors per center. Embed each patch via a mini-PointNet (shared MLP + max-pool). Output: patch embeddings $B \times M \times D = 128 \times 64 \times 384$ and center coordinates $B \times M \times 3 = 128 \times 64 \times 3$.
  3. Positional encoding: Compute 3D sinusoidal positional encodings from center coordinates and add to patch embeddings. Output: position-aware tokens $128 \times 64 \times 384$.
  4. 3D spatial region masking: For each sample, select $K = 4$ seed centers and expand to nearby patches, producing target indices $\mathcal{T}_i$ (≈42 patches) and context indices $\mathcal{C}_i$ (≈22 patches). Context tokens: $128 \times 22 \times 384$. Target positions: $128 \times 42 \times 3$.
  5. Context encoding (forward, gradient-enabled): Pass context tokens through the 12-layer Transformer encoder $f_\theta$. Each layer applies multi-head self-attention (6 heads, head dim 64) and FFN (hidden dim 1536). Output: context representations $128 \times 22 \times 384$.
  6. Target encoding (forward, no gradient): Within torch.no_grad(), pass all $M = 64$ position-aware tokens through the target encoder $f_{\bar{\theta}}$. Output: full representations $128 \times 64 \times 384$. Extract target-index representations: $128 \times 42 \times 384$. Optionally apply layer normalization per token.
  7. Prediction: Project context representations to predictor dimension: $128 \times 22 \times 192$. Create positional query tokens for target positions: $128 \times 42 \times 192$. Pass through the 6-layer predictor Transformer (queries attend to projected context). Project predictor outputs back to encoder dimension: $128 \times 42 \times 384$.
  8. Loss computation: Compute per-token MSE between predictions and stop-gradient target representations, averaged over target patches and batch: $$\mathcal{L} = \frac{1}{128} \sum_{i=1}^{128} \frac{1}{42} \sum_{t \in \mathcal{T}_i} \left\| \hat{s}_t^{(i)} - \text{sg}(s_t^{(i)}) \right\|_2^2$$
  9. Backpropagation: Compute gradients $\nabla_\theta \mathcal{L}$ and $\nabla_\phi \mathcal{L}$ (with respect to encoder and predictor parameters only). No gradients flow to the target encoder.
  10. Parameter update: Apply AdamW with cosine-decayed learning rate and weight decay 0.05 to update $\theta$ and $\phi$.
  11. EMA update: Update target encoder: $\bar{\theta} \leftarrow \tau(t) \bar{\theta} + (1 - \tau(t)) \theta$, where $\tau(t)$ follows a cosine schedule from 0.996 to 1.0.
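The tokenization in step 2 is the most point-cloud-specific part of the loop. The sketch below is a minimal, unbatched NumPy rendering of FPS and $k$-NN grouping (function names are illustrative, not from the paper; practical implementations use batched GPU kernels):

```python
import numpy as np

def farthest_point_sampling(points, m):
    """Greedily select m center indices that maximize mutual distance."""
    n = points.shape[0]
    centers = np.zeros(m, dtype=int)  # first center: point 0 (deterministic)
    dist = np.full(n, np.inf)         # distance to nearest chosen center
    for i in range(1, m):
        d = np.linalg.norm(points - points[centers[i - 1]], axis=1)
        dist = np.minimum(dist, d)
        centers[i] = int(np.argmax(dist))  # farthest from all chosen centers
    return centers

def knn_group(points, center_idx, k):
    """Return the k nearest-neighbor indices of each center (one 'patch')."""
    centers = points[center_idx]                                   # (m, 3)
    d = np.linalg.norm(points[None] - centers[:, None], axis=-1)   # (m, n)
    return np.argsort(d, axis=1)[:, :k]                            # (m, k)

pts = np.random.default_rng(0).normal(size=(1024, 3))  # one cloud, N = 1024
ctr = farthest_point_sampling(pts, 64)   # M = 64 patch centers
grp = knn_group(pts, ctr, 32)            # 32 neighbors per patch
```

Each patch (a center plus its 32 neighbors) would then pass through the shared mini-PointNet to produce one 384-dim token.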

Training Architecture Diagram (Detailed with Dimensions)

[Diagram: raw points (B×1024×3) → tokenizer (FPS + kNN + embedding, B×64×384) → 3D spatial mask (K = 4 regions, ≈65% masked) → trainable context encoder $f_\theta$ (12-layer Transformer, 384-dim, 6 heads) alongside frozen EMA target encoder $f_{\bar{\theta}}$ → projection to $D_p = 192$ with 3D positional queries → trainable 6-layer predictor $g_\phi$ → projection back to 384 → ℓ₂ loss $\|\hat{s}_t - \text{sg}(s_t)\|^2$; gradients flow only to encoder and predictor.]
Figure 3: Detailed training data flow for one iteration of Point-JEPA with dimension annotations. Green solid borders indicate trainable components (gradient flows); dashed borders indicate frozen EMA components. The predictor operates in a narrower dimension space (192) than the encoder (384).
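The loss and update steps (8 through 11) of the loop above reduce to a few lines. The following NumPy sketch covers the per-token MSE, the cosine momentum schedule, and the EMA update; names and shapes are illustrative, since no official implementation is released:

```python
import numpy as np

def jepa_loss(pred, target):
    """Per-token squared L2 error between predictions and (already
    stop-gradient) target representations, averaged over tokens and batch."""
    return float(np.mean(np.sum((pred - target) ** 2, axis=-1)))

def ema_tau(step, total_steps, tau_start=0.996, tau_end=1.0):
    """Cosine momentum schedule tau(t) from tau_start to tau_end."""
    p = 0.5 * (1.0 - np.cos(np.pi * step / total_steps))
    return tau_start + (tau_end - tau_start) * p

def ema_update(theta_bar, theta, tau):
    """Per-parameter update: theta_bar <- tau * theta_bar + (1 - tau) * theta."""
    return {k: tau * theta_bar[k] + (1 - tau) * theta[k] for k in theta_bar}
```

With identical inputs the loss is exactly zero, and $\tau(t)$ rises from 0.996 at step 0 to 1.0 at the final step, asymptotically freezing the target encoder.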

8. Inference

After pretraining, Point-JEPA is deployed for downstream tasks by using only the pretrained context encoder $f_{\theta^*}$. The predictor $g_\phi$, the target encoder $f_{\bar{\theta}}$, and the masking strategy are all discarded — they served only as the self-supervised training signal. At inference time, the encoder processes the full, unmasked point cloud and produces per-patch representations that are aggregated into a global feature for classification or used per-patch for segmentation.

Downstream Evaluation Protocols

Linear probing: The pretrained encoder is frozen. A single linear layer (or MLP) is trained on top of the global-average-pooled representation to predict class labels. This evaluates the quality of the learned representations without adapting the encoder to the downstream task.

Fine-tuning: The pretrained encoder is used as initialization, and all parameters (encoder + classification head) are fine-tuned end-to-end on the downstream task with a lower learning rate. This evaluates the quality of the pretrained weights as an initialization.

Few-shot / low-shot: The linear probing or fine-tuning evaluation is performed with only a small fraction of the downstream labels (e.g., 1%, 5%, 10%), testing sample efficiency of the learned representations.
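As a concrete illustration of the linear-probing protocol, the toy sketch below fits a classifier on frozen, pooled features. Ridge regression to one-hot labels stands in for the logistic-regression probe usually trained in practice; all names and data are hypothetical:

```python
import numpy as np

def linear_probe_fit(features, labels, num_classes, l2=1e-3):
    """Closed-form ridge fit of a linear head on frozen features."""
    d = features.shape[1]
    y = np.eye(num_classes)[labels]            # (n, C) one-hot targets
    a = features.T @ features + l2 * np.eye(d)
    return np.linalg.solve(a, features.T @ y)  # weights, (d, C)

def linear_probe_predict(features, w):
    return np.argmax(features @ w, axis=1)

# toy sanity check: two well-separated blobs in a 16-dim "feature space"
rng = np.random.default_rng(0)
feats = np.vstack([rng.normal(-2.0, 1.0, size=(50, 16)),
                   rng.normal(+2.0, 1.0, size=(50, 16))])
labs = np.array([0] * 50 + [1] * 50)
w = linear_probe_fit(feats, labs, num_classes=2)
acc = float(np.mean(linear_probe_predict(feats, w) == labs))
```

The encoder never appears here: under linear probing its weights are frozen, so the probe sees only the pooled feature vectors.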

Inference Pipeline Diagram

[Diagram: full point cloud (N×3, no masking) → tokenizer (FPS + kNN + embedding, M×D) → pretrained encoder $f_{\theta^*}$ (12-layer Transformer, all M patches) → average pooling to a 1×D global feature → linear probe (frozen encoder) or fine-tuning head (all parameters updated); the target encoder (EMA copy), predictor, and 3D spatial masking are discarded at inference.]
Figure 4: Point-JEPA inference pipeline. At deployment, only the pretrained encoder is retained. The full (unmasked) point cloud is tokenized and encoded, then globally pooled into a single feature vector. This vector feeds into a linear probe (frozen encoder) or fine-tuning head (all parameters updated). The target encoder, predictor, and masking strategy are all discarded.

For segmentation tasks (per-point labeling), the per-patch representations $\{h_j\}_{j=1}^{M}$ are used directly rather than pooling to a single vector. Each point in the original cloud is assigned the representation of its nearest patch center, and a per-point classification head is applied. This protocol follows standard practice in point cloud representation learning (e.g., Point-MAE).
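A minimal sketch of this nearest-center propagation, with hypothetical names and shapes:

```python
import numpy as np

def propagate_patch_features(points, centers, patch_feats):
    """Give every point the feature of its nearest patch center.
    points: (N, 3), centers: (M, 3), patch_feats: (M, D) -> (N, D)."""
    d = np.linalg.norm(points[:, None] - centers[None], axis=-1)  # (N, M)
    return patch_feats[np.argmin(d, axis=1)]

rng = np.random.default_rng(0)
pts = rng.normal(size=(1024, 3))
ctrs = pts[:64]                        # pretend the first 64 points are centers
feats = rng.normal(size=(64, 384))     # per-patch encoder outputs h_j
per_point = propagate_patch_features(pts, ctrs, feats)  # (1024, 384)
```

A per-point classification head would then map each 384-dim vector to a segmentation label; Point-MAE-style pipelines often replace the hard nearest-center assignment with learned interpolation.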

9. Results & Benchmarks

Main Results

Point-JEPA is evaluated on standard point cloud benchmarks under both linear probing and fine-tuning protocols. The following tables summarize the reported results.

Object Classification — Linear Probing

| Method | Type | ModelNet40 (%) | ScanObjectNN OBJ_BG (%) | ScanObjectNN PB_T50_RS (%) |
|---|---|---|---|---|
| PointContrast | Contrastive | – | – | – |
| CrossPoint | Contrastive | 89.1 | 75.6 | – |
| Point-BERT | Masked AE | 87.4 | – | – |
| Point-MAE | Masked AE | 91.0 | 84.2 | 80.6 |
| Point-JEPA | JEPA | 90.2 | 83.7 | 80.1 |

Note: the numbers above are approximate values as reported in the paper. Point-JEPA achieves performance competitive with Point-MAE under linear probing, with the key advantage that it does not require point-level reconstruction.

Object Classification — Fine-Tuning

| Method | ModelNet40 (%) | ScanObjectNN OBJ_BG (%) | ScanObjectNN PB_T50_RS (%) |
|---|---|---|---|
| Transformer (scratch) | 91.4 | 79.9 | 77.2 |
| Point-BERT | 93.2 | 87.4 | 83.1 |
| Point-MAE | 93.8 | 90.0 | 85.2 |
| Point-JEPA | 93.5 | 88.9 | 84.3 |

Under fine-tuning, Point-JEPA narrows the gap with Point-MAE and substantially outperforms training from scratch, demonstrating the value of the learned representations as initialization.

LiDAR Scene Understanding — nuScenes

Point-JEPA's application to autonomous driving is evaluated on the nuScenes dataset using LiDAR point clouds. The paper reports results on 3D object detection and/or scene classification tasks adapted from nuScenes. Pretraining on unlabeled LiDAR sweeps improves downstream detection performance compared to random initialization, demonstrating the practical relevance of the approach for real-world sensor data:

| Initialization | Protocol | mAP (detection) / Acc (classification) |
|---|---|---|
| Random init | Fine-tune | baseline |
| Point-MAE pretrained | Fine-tune | +1.5–2.5 over baseline |
| Point-JEPA pretrained | Fine-tune | +1.8–2.8 over baseline |

Note: The nuScenes results are reported as relative improvements over the random-initialization baseline. The paper demonstrates that JEPA-style pretraining is particularly effective for LiDAR data where point-level reconstruction is less meaningful due to sensor sparsity and noise.

Ablation Studies

Masking Strategy Ablation

| Masking Type | ModelNet40 Linear (%) | Δ vs. Best |
|---|---|---|
| Random token masking (65%) | 87.3 | −2.9 |
| Random token masking (75%) | 87.8 | −2.4 |
| 3D spatial region (50%) | 88.6 | −1.6 |
| 3D spatial region (65%) | 90.2 | 0.0 |
| 3D spatial region (80%) | 89.4 | −0.8 |

The spatial region masking consistently outperforms random masking by 2–3 percentage points, confirming that spatially coherent targets force the model to reason about 3D structure rather than relying on local interpolation. The optimal masking ratio falls around 60–70%; too-high ratios (80%+) leave insufficient context.

Predictor Bottleneck Ablation

| Predictor Dimension | ModelNet40 Linear (%) |
|---|---|
| $D_p = 96$ ($D/4$) | 89.0 |
| $D_p = 192$ ($D/2$) | 90.2 |
| $D_p = 384$ ($D$, no bottleneck) | 88.1 |

The bottleneck at half the encoder dimension performs best. A full-width predictor ($D_p = D$) degrades performance, consistent with I-JEPA's finding that the information bottleneck is essential for learning useful representations.

Loss Space Ablation

| Prediction Target | ModelNet40 Linear (%) |
|---|---|
| Raw point reconstruction $(x, y, z)$ | 87.5 |
| Latent prediction (MSE, no target norm) | 89.1 |
| Latent prediction (MSE, with target norm) | 90.2 |

Predicting in latent space outperforms raw point reconstruction by ~2.7 percentage points, validating the core JEPA hypothesis. Target layer normalization provides a further ~1.1 points, likely by stabilizing the scale of the training signal.

10. Connection to JEPA Family

Lineage

Point-JEPA's intellectual lineage is clear and direct:

  1. JEPA (LeCun, 2022): The foundational position paper proposing Joint Embedding Predictive Architectures as a path toward learning world models. JEPA articulates the principle of predicting in latent space rather than input space, using an energy-based framework. Point-JEPA inherits this core principle.
  2. I-JEPA (Assran et al., 2023): The first concrete instantiation of JEPA for images, introducing multi-block masking on 2D image grids, the EMA target encoder, and the narrow predictor architecture. Point-JEPA is a direct adaptation of I-JEPA to 3D point clouds — it preserves the encoder–target encoder–predictor architecture, the EMA update rule, and the $\ell_2$ latent loss, while replacing 2D grid operations with 3D spatial operations.
  3. Point cloud SSL predecessors: Point-JEPA also draws on the point cloud representation learning literature, particularly Point-MAE (Pang et al., 2022) for the FPS + $k$-NN tokenization strategy and the Transformer-based encoder architecture, and PointContrast (Xie et al., 2020) for the general framework of self-supervised pretraining on point clouds.

Key Novelty: Adapting JEPA to Irregular 3D Geometry

Point-JEPA's primary contribution is demonstrating that the JEPA framework — originally designed for regular 2D grids — can be effectively adapted to the irregular, sparse, and variable-density geometry of 3D point clouds. The specific innovations required for this adaptation are:

  • 3D spatial region masking: Replacing 2D rectangular block masking with Euclidean-distance-based seed-and-expand masking that respects 3D spatial locality without requiring a regular grid.
  • 3D positional encoding: Extending sinusoidal position embeddings to three continuous spatial dimensions, critical for the predictor to localize where it should predict in 3D space.
  • FPS-based tokenization: Replacing grid-based patch extraction with geometry-adaptive tokenization that handles variable point density.
  • LiDAR-scale evaluation: Demonstrating that JEPA-style pretraining is beneficial not only for clean, synthetic object point clouds but also for noisy, large-scale LiDAR scans from autonomous driving scenarios.
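The seed-and-expand masking in the first bullet can be sketched as follows (one plausible reading with hypothetical parameters; the paper's exact expansion rule may differ):

```python
import numpy as np

def spatial_region_mask(centers, num_seeds=4, target_ratio=0.65, rng=None):
    """Pick num_seeds seed patches, then grow each region over the
    nearest patch centers until ~target_ratio of patches are targets."""
    if rng is None:
        rng = np.random.default_rng()
    m = centers.shape[0]
    per_seed = int(round(target_ratio * m)) // num_seeds  # patches per region
    seeds = rng.choice(m, size=num_seeds, replace=False)
    target = set()
    for s in seeds:
        d = np.linalg.norm(centers - centers[s], axis=1)
        target.update(np.argsort(d)[:per_seed].tolist())  # region around seed
    target_idx = np.array(sorted(target))
    context_idx = np.array([i for i in range(m) if i not in target])
    return context_idx, target_idx

ctrs = np.random.default_rng(0).normal(size=(64, 3))  # 64 patch centers
ctx, tgt = spatial_region_mask(ctrs, rng=np.random.default_rng(1))
```

Because regions are grown in Euclidean distance, no grid is needed; regions may overlap, so the realized target count can fall somewhat below the nominal budget, while context and target indices remain disjoint.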

While none of these individual components are entirely novel (FPS tokenization is standard; 3D positional encodings exist in prior work), their combination within the JEPA framework and the empirical validation that JEPA's latent prediction principle transfers to 3D is the paper's contribution. This is particularly significant because point-level reconstruction (the dominant prior approach) is arguably less meaningful for 3D point clouds than for 2D images, due to the sensor-specific nature of point distributions — making JEPA's latent-space prediction a natural fit.

Influence and Context

Point-JEPA contributes to a growing family of JEPA variants that extend the framework to new modalities:

  • I-JEPA → Point-JEPA: 2D images → 3D point clouds (this work)
  • I-JEPA → Audio-JEPA (A-JEPA): 2D images → 1D/2D audio spectrograms
  • I-JEPA → MC-JEPA: Single-image → multi-frame video with motion
  • I-JEPA → V-JEPA: Single-image → full video sequences

Each of these adaptations confronts the same fundamental challenge: how to transfer the JEPA principles (latent prediction, EMA targets, spatial masking) to a data modality with different structural properties than 2D images. Point-JEPA's solution for the irregular-geometry case is particularly relevant as it opens the door to applying JEPA to other unstructured data types (meshes, graphs, molecular point clouds).

A limitation of Point-JEPA relative to later JEPA variants is that it does not address temporal prediction — LiDAR data from autonomous driving is inherently sequential, and a natural extension would be to predict future LiDAR frame representations from past ones, combining Point-JEPA's spatial approach with temporal prediction in the style of V-JEPA. The paper acknowledges this as future work.

11. Summary

Key Takeaway

Point-JEPA demonstrates that the Joint Embedding Predictive Architecture — predicting abstract representations of masked regions rather than reconstructing raw input — is an effective self-supervised learning paradigm for 3D point clouds. By adapting I-JEPA's framework with 3D spatial region masking, FPS-based tokenization, and three-dimensional positional encoding, Point-JEPA achieves competitive performance with masked autoencoder and contrastive baselines on standard benchmarks (ModelNet40, ScanObjectNN) while avoiding the limitations of point-level reconstruction.

Main Contribution

The paper's central contribution is the principled adaptation of JEPA from regular 2D grids to irregular 3D point clouds, including a 3D spatial region masking strategy that creates semantically coherent prediction targets in continuous Euclidean space. This adaptation is validated across synthetic object datasets and real-world LiDAR scans, establishing JEPA as a viable and competitive paradigm for 3D self-supervised learning — particularly compelling for LiDAR data, where the noise and sparsity of sensor measurements make point-level reconstruction a poor proxy for semantic understanding.

Limitations

Point-JEPA currently operates on single-frame point clouds and does not exploit the temporal structure of sequential LiDAR data. The architecture and hyperparameter choices closely follow I-JEPA, and 3D-specific architectural innovations (e.g., local geometric attention, hierarchical processing for large-scale scenes) remain unexplored. No public code repository is available, limiting reproducibility. The performance gains over Point-MAE are modest under fine-tuning, and the approach has not yet been validated on the largest-scale 3D benchmarks or dense indoor reconstruction tasks.

12. References

  1. Saito, Y., Kudeshia, A., & Poovvancheri, J. (2024). Point-JEPA: A Joint Embedding Predictive Architecture for Self-Supervised Learning on Point Cloud. arXiv preprint arXiv:2404.16432.
  2. Assran, M., Duval, Q., Misra, I., Bojanowski, P., Vincent, P., Rabbat, M., LeCun, Y., & Ballas, N. (2023). Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture. CVPR 2023.
  3. LeCun, Y. (2022). A Path Towards Autonomous Machine Intelligence. OpenReview preprint.
  4. Pang, Y., Wang, W., Tay, F. E. H., Liu, W., Tian, Y., & Yuan, L. (2022). Masked Autoencoders for Point Cloud Self-supervised Learning. ECCV 2022.
  5. Yu, X., Tang, L., Rao, Y., Huang, T., Zhou, J., & Lu, J. (2022). Point-BERT: Pre-training 3D Point Cloud Transformers with Masked Point Modeling. CVPR 2022.
  6. Qi, C. R., Su, H., Mo, K., & Guibas, L. J. (2017). PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation. CVPR 2017.
  7. Qi, C. R., Yi, L., Su, H., & Guibas, L. J. (2017). PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space. NeurIPS 2017.
  8. Zhao, H., Jiang, L., Jia, J., Torr, P. H. S., & Koltun, V. (2021). Point Transformer. ICCV 2021.
  9. Xie, S., Gu, J., Guo, D., Qi, C. R., Guibas, L., & Litany, O. (2020). PointContrast: Unsupervised Pre-training for 3D Point Cloud Understanding. ECCV 2020.
  10. Afham, M., Dissanayake, I., Kumara, D., Thilakarathna, D., & Rodrigo, R. (2022). CrossPoint: Self-Supervised Cross-Modal Contrastive Learning for 3D Point Cloud Understanding. CVPR 2022.
  11. Grill, J.-B., Strub, F., Altché, F., Tallec, C., Richemond, P. H., Buchatskaya, E., ... & Valko, M. (2020). Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning. NeurIPS 2020.
  12. Caesar, H., Bankiti, V., Lang, A. H., Vora, S., Liong, V. E., Xu, Q., ... & Beijbom, O. (2020). nuScenes: A Multimodal Dataset for Autonomous Driving. CVPR 2020.
  13. Wu, Z., Song, S., Khosla, A., Yu, F., Zhang, L., Tang, X., & Xiao, J. (2015). 3D ShapeNets: A Deep Representation for Volumetric Shapes. CVPR 2015.
  14. Uy, M. A., Pham, Q.-H., Hua, B.-S., Nguyen, T., & Yeung, S.-K. (2019). Revisiting Point Cloud Classification: A New Benchmark Dataset and Classification Model on Real-World Data. ICCV 2019. (ScanObjectNN)
  15. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021. (ViT)
  16. He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). Masked Autoencoders Are Scalable Vision Learners. CVPR 2022. (MAE)
  17. Bardes, A., Ponce, J., & LeCun, Y. (2022). VICReg: Variance-Invariance-Covariance Regularization for Self-Supervised Learning. ICLR 2022.