# PETR: Position Embedding Transformation for Multi-View 3D Object Detection

By [Perception](https://paragraph.com/@perception) · 2025-05-02

---

### Problem

DETR performs only 2D detection, while DETR3D depends on predicted 3D reference points: inaccurate coordinate prediction degrades the sampled features, and the repeated projection-and-sampling procedure adds complexity.

![](https://storage.googleapis.com/papyrus_images/ad9a951f3bc35f3a3e571c0568b72a69.png)

(a) In DETR, the object queries interact with 2D features to perform 2D detection. (b) DETR3D repeatedly projects the generated 3D reference points into the image plane and samples the 2D features to interact with object queries in the decoder. (c) PETR generates 3D position-aware features by encoding the 3D position embedding (3D PE) into 2D image features. The object queries directly interact with the 3D position-aware features and output 3D detection results.

The PETR framework is shown below.

![](https://storage.googleapis.com/papyrus_images/9a758faaec21b47a4e316ecd18adb426.png)

Steps:

1.  Discretize the camera frustum space into meshgrid coordinates of size $$(W\_F, H\_F, D)$$.
    
2.  Transform the meshgrid coordinates into 3D world space using the camera parameters.
    
3.  Extract 3D position-aware features by combining the 2D image features and the 3D world-space coordinates with the **3D Position Encoder**.
    
4.  Object queries, obtained by transforming 3D anchor points with an MLP (the query generator), interact with the 3D position-aware features to predict 3D bounding boxes.
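The frustum discretization and world-space transform in steps 1–2 can be sketched as follows (a minimal sketch: the depth range and the 4×4 inverse projection matrix `K_inv` are illustrative assumptions, not values from the paper):

```python
import torch

def frustum_to_world(W_F, H_F, D, K_inv):
    """Discretize the camera frustum and map it to 3D world space.

    K_inv: [4, 4] inverse camera projection matrix (assumed known).
    Returns world-space homogeneous points of shape [H_F, W_F, D, 4].
    """
    u = torch.arange(W_F, dtype=torch.float32)   # pixel u coordinates
    v = torch.arange(H_F, dtype=torch.float32)   # pixel v coordinates
    d = torch.linspace(1.0, 60.0, D)             # depth bins (assumed range)

    # Meshgrid over the frustum: every (u, v, d) combination
    v_g, u_g, d_g = torch.meshgrid(v, u, d, indexing='ij')  # each [H_F, W_F, D]

    # p_j^m = (u*d, v*d, d, 1)^T, homogeneous frustum points
    ones = torch.ones_like(d_g)
    p_m = torch.stack([u_g * d_g, v_g * d_g, d_g, ones], dim=-1)  # [H_F, W_F, D, 4]

    # p_{i,j}^{3d} = K_i^{-1} p_j^m : transform into 3D world space
    p_3d = p_m @ K_inv.T                          # [H_F, W_F, D, 4]
    return p_3d
```

Because the meshgrid is shared across views, only the per-view `K_inv` changes when applying this to all $$N$$ cameras.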
    

### Implicit Neural Representation (INR)

INR maps coordinates to visual cues with an MLP, efficiently modeling 3D objects, 3D scenes, and 2D images.

*   PETR borrows this idea: the object queries are obtained by transforming 3D anchor-point coordinates with an MLP.
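As a toy illustration of INR (the network sizes here are my own, not from the paper), an MLP can map a 2D coordinate directly to an RGB value; fitting it to one image gives an implicit, resolution-free representation of that image:

```python
import torch
import torch.nn as nn

# A coordinate network: (x, y) -> (r, g, b)
inr = nn.Sequential(
    nn.Linear(2, 64),
    nn.ReLU(),
    nn.Linear(64, 3),
    nn.Sigmoid(),  # keep predicted colors in [0, 1]
)

coords = torch.rand(1024, 2)   # query coordinates in [0, 1]^2
rgb = inr(coords)              # [1024, 3] predicted colors
```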
    

### Method

1.  Input images $$I = \\{I\_i \\in \\mathcal{R}^{3\\times H\_I \\times W\_I}, i = 1, 2, \\dots, N\\}$$ from $$N$$ views.
    
2.  Backbone Network: extracts 2D multi-view features $$F^{2d} = \\{F\_i^{2d} \\in \\mathcal{R}^{C \\times H\_F \\times W\_F}, i=1,2,\\dots, N\\}$$
    
3.  Discretize the camera frustum space into meshgrid coordinates of size $$(W\_F, H\_F, D)$$. Each point in the meshgrid is denoted as $$p^m\_j = (u\_j \\times d\_j, v\_j \\times d\_j, d\_j, 1)^T$$, where $$(u\_j, v\_j)$$ is the pixel coordinate in the image, $$d\_j$$ is the depth along the camera axis, and $$j$$ indexes the $$j$$-th meshgrid point.
    
    ![](https://storage.googleapis.com/papyrus_images/e7803971a4458ffa0cbafe7e9037ac39.png)
    
4.  Generate 3D world-space coordinates by transforming the meshgrid points with the camera parameters. Since the meshgrid is shared across views, the corresponding 3D world-space point for each view is
    
    ![](https://storage.googleapis.com/papyrus_images/98d9cfdd51f64a0960abf756adeddb48.png)
    
    The principal point $$(c\_u, c\_v)$$ (written $$(o\_x, o\_y)$$ above) is the point where the optical axis pierces the sensor.
    
    $$p\_{i,j}^{3d} = K\_i^{-1}p\_j^m, ~~p\_{i,j}^{3d} = (x\_{i,j}, y\_{i,j}, z\_{i,j}, 1)^T,$$ where $$K\_i$$ is the transformation matrix of the $$i$$-th view.
    
    ![](https://storage.googleapis.com/papyrus_images/f8740918aec336a80ac001d017a782e3.png)
    
5.  Since there are $$D$$ depths in the meshgrid, each of the $$H\_F \\times W\_F$$ positions carries $$D \\times 4$$ normalized coordinates, denoted as $$P^{3d} = \\{P\_i^{3d} \\in \\mathcal{R}^{(D \\times 4) \\times H\_F \\times W\_F}, i=1,2,\\dots, N\\}$$
    
6.  **3D position encoder** produces 3D position-aware features, i.e., $$F^{3d} = \\{F\_i^{3d} \\in \\mathcal{R}^{C \\times H\_F \\times W\_F}, i=1,2,\\dots, N\\}$$
    
    ![](https://storage.googleapis.com/papyrus_images/3498314c9cb4a30337f5bfc208163eed.png)
    
    The 3D position encoder architecture computes $$F^{3d}\_i = \\psi(F\_i^{2d}, P\_i^{3d}), ~ i = 1, 2, \\dots, N$$
    
7.  **Query Generator**: a set of learnable anchor points in 3D space, initialized with a uniform distribution, is fed to a small MLP (two linear layers) to obtain the initial object queries $$Q\_0$$. Initializing queries from 3D anchor points helps convergence.
    
8.  **Decoder**: $$\\mathcal{Q}\_l = \\Omega\_l(F^{3d}, \\mathcal{Q}\_{l-1}), l = 1, \\dots, L$$
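Steps 6–8 can be sketched with standard PyTorch modules (a simplified sketch: the layer sizes, the 1×1-conv encoder, and the use of `nn.TransformerDecoderLayer` are illustrative stand-ins for PETR's actual modules):

```python
import torch
import torch.nn as nn

class PositionEncoder3D(nn.Module):
    """psi(F_2d, P_3d): fuse 2D features with 3D coordinate embeddings."""
    def __init__(self, C=256, D=64):
        super().__init__()
        # 1x1 convs embed the (D*4)-channel coordinate volume into C channels
        self.coord_mlp = nn.Sequential(
            nn.Conv2d(D * 4, C, 1), nn.ReLU(), nn.Conv2d(C, C, 1)
        )

    def forward(self, f2d, p3d):
        # f2d: [N, C, H_F, W_F], p3d: [N, D*4, H_F, W_F]
        return f2d + self.coord_mlp(p3d)  # 3D position-aware features F_3d

# Decoder: L stacked layers; queries attend to the 3D-aware features
C, L = 256, 6
layers = nn.ModuleList(
    nn.TransformerDecoderLayer(d_model=C, nhead=8, batch_first=True)
    for _ in range(L)
)

f3d = torch.randn(1, 600, C)   # flattened multi-view features (N*H_F*W_F tokens)
q = torch.randn(1, 900, C)     # initial object queries Q_0
for layer in layers:
    q = layer(q, f3d)          # Q_l = Omega_l(F_3d, Q_{l-1})
```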
    

### Sample Code

Positional Embedding (3D)

    import torch
    import torch.nn as nn
    
    class PositionEmbedding3D(nn.Module):
        def __init__(self, embed_dim=256, num_depth_bins=64):
            super().__init__()
            self.embed_dim = embed_dim
            
            # MLP to encode (X, Y, Z) into embedding space
            self.mlp = nn.Sequential(
                nn.Linear(3, embed_dim // 2),
                nn.ReLU(),
                nn.Linear(embed_dim // 2, embed_dim)
            )
            
            # Precompute depth bins (can also be learned)
            self.register_buffer('depth_bins', torch.linspace(1, 100, num_depth_bins))
    
        def forward(self, H, W, device):
            depth = self.depth_bins.to(device)  # [D]
            
            # 1. Generate 2D pixel coordinates (u, v)
            u = torch.arange(W, dtype=torch.float32, device=device)  # [W]
            v = torch.arange(H, dtype=torch.float32, device=device)  # [H]
            v, u = torch.meshgrid(v, u, indexing='ij')  # each [H, W]
            
            # Normalize coordinates to [-1, 1]
            u = (u / W) * 2 - 1  # [H, W]
            v = (v / H) * 2 - 1  # [H, W]
            
            # 2. Generate 3D points for all depth bins
            X = u.unsqueeze(-1) * depth  # [H, W, D]
            Y = v.unsqueeze(-1) * depth  # [H, W, D]
            Z = depth.reshape(1, 1, -1).expand(H, W, -1)  # [H, W, D]
            xyz = torch.stack([X, Y, Z], dim=-1)  # [H, W, D, 3]
            
            # 3. Project 3D points into embedding space
            xyz = xyz.flatten(0, 2)  # [H*W*D, 3]
            embedding = self.mlp(xyz)  # [H*W*D, embed_dim]
            embedding = embedding.view(H, W, -1, self.embed_dim)  # [H, W, D, embed_dim]
            
            return embedding

Query Generator

    class PETRQueryGenerator(nn.Module):
        def __init__(self, num_queries=900, embed_dim=256):
            super().__init__()
            # Learnable 3D anchor points, initialized with a uniform distribution
            self.anchor_points = nn.Parameter(torch.rand(num_queries, 3))
            # Small MLP (two linear layers) maps anchor points to initial queries Q_0
            self.mlp = nn.Sequential(
                nn.Linear(3, embed_dim),
                nn.ReLU(),
                nn.Linear(embed_dim, embed_dim)
            )
    
        def forward(self):
            # [num_queries, embed_dim] initial object queries for the decoder
            return self.mlp(self.anchor_points)

  

9.  **Detection Head and Loss in PETR**
    

**(a) Detection Head**

The **detection head** in PETR consists of two parallel branches:

*   **Classification Branch**: Predicts the probability distribution over object classes (e.g., car, pedestrian, cyclist).
    
*   **Regression Branch**: Predicts the **7-DoF 3D bounding box** parameters:
    
    *   **Center** $$(x, y, z)$$ in world coordinates.
        
    *   **Dimensions** $$(w, h, l)$$ (width, height, length).
        
    *   **Rotation** $$\\theta$$ around the vertical axis (yaw).
        

Each **object query** from the decoder outputs:

*   A class score vector (via softmax).
    
*   A 7-DoF box regression vector (unbounded, real-valued).
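A minimal sketch of the two-branch head (the hidden sizes and layer counts here are illustrative; PETR's head follows DETR-style feed-forward networks):

```python
import torch
import torch.nn as nn

class PETRHead(nn.Module):
    def __init__(self, embed_dim=256, num_classes=10):
        super().__init__()
        # Classification branch: per-query class logits
        self.cls_branch = nn.Linear(embed_dim, num_classes)
        # Regression branch: 7-DoF box (x, y, z, w, h, l, yaw), unbounded
        self.reg_branch = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.ReLU(),
            nn.Linear(embed_dim, 7),
        )

    def forward(self, queries):
        # queries: [B, num_queries, embed_dim] from the decoder
        return self.cls_branch(queries), self.reg_branch(queries)
```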
    

**(b) Objective Function**

PETR optimizes detection using a **multi-task loss**:

$$\\mathcal{L} = \\lambda\_{cls} \\cdot \\mathcal{L}\_{cls} + \\lambda\_{reg} \\cdot \\mathcal{L}\_{reg}$$

*   **Classification Loss** ($$\\mathcal{L}\_{cls}$$):
    
    *   **Focal Loss** (variant of cross-entropy for class imbalance):  
        $$ \\mathcal{L}\_{cls} = -\\alpha\_t (1 - p\_t)^\\gamma \\log(p\_t) $$
        
        *   $$p\_t$$: Predicted probability for the target class.
            
        *   $$\\alpha\_t, \\gamma$$: Hyperparameters to balance easy/hard samples.
            
*   **Regression Loss** ($$\\mathcal{L}\_{reg}$$):
    
    *   **L1 Loss** (for box coordinates):  
        $$\\mathcal{L}\_{reg} = \\| \\mathbf{b}\_{pred} - \\mathbf{b}\_{gt} \\|\_1$$
        
        *   $$\\mathbf{b}\_{pred}$$: Predicted box parameters.
            
        *   $$\\mathbf{b}\_{gt}$$: Ground-truth box.
            
    *   **Optional**: Smooth L1 or GIoU loss for rotation stability.
        
*   **Balancing Weights** ($$\\lambda\_{cls}, \\lambda\_{reg}$$):
    
    *   Typically $$\\lambda\_{cls} = 1.0$$, $$\\lambda\_{reg} = 5.0$$ (empirically tuned).
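The multi-task loss can be written down directly from the formulas above (a sketch: a constant $$\\alpha$$ stands in for $$\\alpha\_t$$, and the Hungarian matching that pairs predictions with ground truth is omitted):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # logits: [N, num_classes], targets: [N] class indices
    log_p = F.log_softmax(logits, dim=-1)
    p_t = log_p.gather(1, targets.unsqueeze(1)).squeeze(1).exp()  # prob of target class
    return (-alpha * (1 - p_t) ** gamma * p_t.log()).mean()

def petr_loss(logits, targets, b_pred, b_gt, w_cls=1.0, w_reg=5.0):
    l_cls = focal_loss(logits, targets)
    l_reg = F.l1_loss(b_pred, b_gt)   # L1 over the 7-DoF box parameters
    return w_cls * l_cls + w_reg * l_reg
```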
        

  

### References

1.  PETR: [https://arxiv.org/abs/2203.05625](https://arxiv.org/abs/2203.05625)
    
2.  Medium Link: [https://medium.com/@jiangmen28/petr-position-embedding-transformation-for-multi-view-3d-object-detection-70cbeb5c3701](https://medium.com/@jiangmen28/petr-position-embedding-transformation-for-multi-view-3d-object-detection-70cbeb5c3701)

---

*Originally published on [Perception](https://paragraph.com/@perception/petr-position-embedding-transformation-for-multi-view-3d-object-detection)*
