
Official GitHub repo for VecKM. A very efficient and descriptive local geometry encoder / point tokenizer / patch embedder. ICML2024.


dhyuan99/VecKM


VecKM: A Linear Time and Space Local Point Cloud Geometry Encoder

Dehao Yuan ,  Cornelia Fermüller ,  Tahseen Rabbani ,  Furong Huang ,  Yiannis Aloimonos  

ICML 2024 [arXiv]

Highlighted Features

Usage

ℹ️ This section is illustrated in examples/example.ipynb.

⚠️ VecKM is sensitive to scaling. Please scale your data so that your local point cloud lies within a UNIT BALL of radius 1.

⚠️ For example, if you have a point cloud pts and want to consider the local geometry within radius 0.1, do pts *= 10 so that you are now considering the local geometry within radius 1.

⚠️ If your x, y, z coordinates do not share the same scale, rescale them so that they do.

⚠️ VecKM is not rotation-invariant. If the local point cloud is rotated, the encoding can be very different.
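The scaling warnings above amount to putting all three axes in a common unit and then dividing by the neighborhood radius. A minimal sketch, assuming NumPy; rescale_for_veckm and its axis_scale argument are hypothetical helpers for illustration, not part of the VecKM code:

```python
import numpy as np


def rescale_for_veckm(pts, radius, axis_scale=(1.0, 1.0, 1.0)):
    """Hypothetical helper: map a neighborhood of `radius` onto the unit ball.

    axis_scale: per-axis factors converting x, y, z into one common unit,
    e.g. (1, 1, 0.001) if x, y are in metres but z is stored in millimetres.
    radius: the neighborhood radius of interest, in that common unit.
    """
    pts = np.asarray(pts, dtype=np.float64) * np.asarray(axis_scale, dtype=np.float64)
    # After this division, a neighborhood of the requested radius has radius 1.
    return pts / radius


# x, y in metres, z in millimetres; we want neighborhoods of radius 0.1 m.
pts = np.array([[0.1, 0.0, 0.0], [0.0, 0.0, 100.0]])
scaled = rescale_for_veckm(pts, radius=0.1, axis_scale=(1.0, 1.0, 0.001))
print(scaled)
```

With axis_scale left at its default, rescale_for_veckm(pts, 0.1) matches the pts *= 10 example up to floating-point rounding.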

It is very simple to incorporate VecKM into your own code. If your input point cloud pts has shape (n,3) or (b,n,3), the following code gives the VecKM local geometry encoding with output shape (n,d) or (b,n,d). PyTorch >= 1.13.0 is recommended since it has better support for complex tensors, but lower versions should also work.

pip install scipy
pip install complexPyTorch
import torch
import torch.nn as nn
import numpy as np
from scipy.stats import norm

def strict_standard_normal(d):
    # This function generates numbers that look like torch.randn(d), but the values
    # are the d evenly spaced interior quantiles of N(0, 1); only their order is random.
    y = np.linspace(0, 1, d+2)
    x = norm.ppf(y)[1:-1]
    np.random.shuffle(x)
    x = torch.tensor(x).float()
    return x

class VecKM(nn.Module):
    def __init__(self, d=256, alpha=6, beta=1.8, p=4096):
        """ I tested empirically, here are some general suggestions for selecting parameters d and p:
        (alpha=6, beta=1.8) works when the data is scaled so that the neighborhood radius = 1.
        Please ensure your point cloud is appropriately scaled!
        d = 256, p = 4096 is for point cloud size ~20k. Runtime is about 28ms.
        d = 128, p = 8192 is for point cloud size ~50k. Runtime is about 76ms.
        For larger point clouds, enlarge p; if that costs too much, reduce d.
        A general empirical phenomenon is that (d*p) is positively correlated with the encoding quality.

        For the selection of alpha and beta, please see the corresponding README section below.
        """
        super().__init__()
        self.alpha, self.beta, self.d, self.p = alpha, beta, d, p
        self.sqrt_d = d ** 0.5

        self.A = torch.stack(
            [strict_standard_normal(d) for _ in range(3)], 
            dim=0
        ) * alpha
        self.A = nn.Parameter(self.A, False)                                    # (3, d)

        self.B = torch.stack(
            [strict_standard_normal(p) for _ in range(3)], 
            dim=0
        ) * beta
        self.B = nn.Parameter(self.B, False)                                    # (3, p)

    def forward(self, pts):
        """ Compute the dense local geometry encodings of the given point cloud.
        Args:
            pts: (bs, n, 3) or (n, 3) tensor, the input point cloud.

        Returns:
            G: (bs, n, d) or (n, d) tensor
               the dense local geometry encodings. 
               note: it is complex valued. 
        """
        pA = pts @ self.A                                                       # Real(..., n, d)
        pB = pts @ self.B                                                       # Real(..., n, p)
        eA = torch.cat((torch.cos(pA), torch.sin(pA)), dim=-1)                  # Real(..., n, 2d)
        eB = torch.cat((torch.cos(pB), torch.sin(pB)), dim=-1)                  # Real(..., n, 2p)
        G = torch.matmul(
            eB,                                                                 # Real(..., n, 2p)
            eB.transpose(-1,-2) @ eA                                            # Real(..., 2p, 2d)
        )                                                                       # Real(..., n, 2d)
        G = torch.complex(
            G[..., :self.d], G[..., self.d:]
        ) / torch.complex(
            eA[..., :self.d], eA[..., self.d:]
        )                                                                       # Complex(..., n, d)
        G = G / torch.norm(G, dim=-1, keepdim=True) * self.sqrt_d
        return G

vkm = VecKM()
pts = torch.rand((10,1000,3))
print(vkm(pts).shape) # it will be Complex(10,1000,256)
pts = torch.rand((1000,3))
print(vkm(pts).shape) # it will be Complex(1000, 256)

from complexPyTorch.complexLayers import ComplexLinear, ComplexReLU
# You may want to apply a two-layer feature transform to the encoding.
feat_trans = nn.Sequential(
    ComplexLinear(256, 128),
    ComplexReLU(),
    ComplexLinear(128, 128)
)
G = feat_trans(vkm(pts))
G = G.real**2 + G.imag**2 # Real(1000, 128) here; Real(10, 1000, 128) for a batched (10, 1000, 3) input.
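As a sanity check on what the forward pass computes, here is a hedged NumPy sketch. It is our own port, not code from the VecKM repo; the names strict_standard_normal_np, veckm_linear and veckm_quadratic are hypothetical, and scipy's norm.ppf is swapped for the standard library's statistics.NormalDist.inv_cdf. It compares the linear-time factorization eB @ (eB^T @ eA) with an explicit O(n^2) kernel sum, and checks two properties that follow from the formula: translating the cloud leaves the encoding unchanged, while rotating it does not.

```python
import numpy as np
from statistics import NormalDist


def strict_standard_normal_np(d, rng):
    # Same construction as strict_standard_normal: the d interior quantiles
    # i / (d + 1), i = 1..d, of N(0, 1), returned in a random order.
    q = np.array([NormalDist().inv_cdf(i / (d + 1)) for i in range(1, d + 1)])
    rng.shuffle(q)
    return q


def veckm_linear(pts, A, B):
    # NumPy port of VecKM.forward for one (n, 3) cloud; never builds an (n, n) matrix.
    d = A.shape[1]
    pA, pB = pts @ A, pts @ B                                  # (n, d), (n, p)
    eA = np.concatenate([np.cos(pA), np.sin(pA)], axis=-1)     # (n, 2d)
    eB = np.concatenate([np.cos(pB), np.sin(pB)], axis=-1)     # (n, 2p)
    G = eB @ (eB.T @ eA)                                       # (n, 2d)
    G = (G[:, :d] + 1j * G[:, d:]) / (eA[:, :d] + 1j * eA[:, d:])
    return G / np.linalg.norm(G, axis=-1, keepdims=True) * np.sqrt(d)


def veckm_quadratic(pts, A, B):
    # Explicit O(n^2) version of the same quantity:
    #   K[j, k] = sum_l cos(B[:, l] . (p_k - p_j))
    #   G[j]    = sum_k K[j, k] * exp(i * A^T (p_k - p_j))
    d = A.shape[1]
    diff = pts[None, :, :] - pts[:, None, :]                   # diff[j, k] = p_k - p_j
    K = np.cos(diff @ B).sum(axis=-1)                          # (n, n)
    G = (K[:, :, None] * np.exp(1j * (diff @ A))).sum(axis=1)  # (n, d)
    return G / np.linalg.norm(G, axis=-1, keepdims=True) * np.sqrt(d)


rng = np.random.default_rng(0)
n, d, p, alpha, beta = 64, 16, 128, 6.0, 1.8
A = np.stack([strict_standard_normal_np(d, rng) for _ in range(3)]) * alpha  # (3, d)
B = np.stack([strict_standard_normal_np(p, rng) for _ in range(3)]) * beta   # (3, p)
pts = rng.uniform(-0.5, 0.5, size=(n, 3))  # already inside the unit ball

G_lin = veckm_linear(pts, A, B)
G_quad = veckm_quadratic(pts, A, B)

shift = np.array([0.3, -0.2, 0.5])
R = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])  # 90 degrees about z
G_shift = veckm_linear(pts + shift, A, B)
G_rot = veckm_linear(pts @ R.T, A, B)

print(np.abs(G_lin - G_quad).max())   # floating-point noise only
print(np.abs(G_lin - G_shift).max())  # translation: also noise
print(np.abs(G_lin - G_rot).max())    # rotation: a large difference
```

The quadratic form is only practical for small n. It makes explicit that the linear version never materializes the (n, n) kernel matrix eB @ eB^T, which appears to be where the linear time and space come from.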

ℹ️ See [Suggestions for picking $\alpha$ and $\beta$] for how to tune the alpha and beta parameters.

ℹ️ See [Suggestion for picking $d$ and $p$] for how to tune the d and p parameters.

ℹ️ Feel free to contact me if you are unsure! I will try to respond within 1 day.

Suggestions for picking $\alpha$ and $\beta$

There are two parameters, alpha and beta, in the VecKM encoding. They control the resolution and the receptive field of VecKM, respectively. A higher alpha produces a more detailed encoding of the local geometry, and a lower alpha produces a more abstract encoding. A higher beta results in a smaller receptive field. See the figure below for a rough intuition.

  • You can slightly increase alpha if you have a relatively dense point cloud and want high-frequency details.
  • You can slightly decrease alpha if you want to smooth out the high-frequency details and only keep the low-frequency components.
  • Beta is closely related to the neighborhood radius; the table below gives the correspondence. For example, to extract the local geometry encoding with radius 0.3, select beta = 6.
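
| beta | radius | beta | radius | beta | radius |
|---|---|---|---|---|---|
| 1 | 1.800 | 11 | 0.163 | 21 | 0.086 |
| 2 | 0.900 | 12 | 0.150 | 22 | 0.082 |
| 3 | 0.600 | 13 | 0.138 | 23 | 0.078 |
| 4 | 0.450 | 14 | 0.129 | 24 | 0.075 |
| 5 | 0.360 | 15 | 0.120 | 25 | 0.072 |
| 6 | 0.300 | 16 | 0.113 | 26 | 0.069 |
| 7 | 0.257 | 17 | 0.106 | 27 | 0.067 |
| 8 | 0.225 | 18 | 0.100 | 28 | 0.065 |
| 9 | 0.200 | 19 | 0.095 | 29 | 0.062 |
| 10 | 0.180 | 20 | 0.090 | 30 | 0.060 |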
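The beta-to-radius table appears to follow radius ≈ 1.8 / beta, consistent with the default beta = 1.8 for a neighborhood of radius 1. This is our reading of the table, not a documented formula. One way to see why beta sets the receptive field (our own derivation, offered as a sketch): if the entries of B are approximately beta·N(0, 1), the normalized kernel K / p between two points a distance δ apart averages cos(beta·z·δ), whose expectation is exp(-beta²δ²/2). At δ = 1.8 / beta that is exp(-1.62) ≈ 0.2 for every beta. The helper names below are hypothetical:

```python
import math
from statistics import NormalDist


def beta_for_radius(radius):
    # Assumed rule of thumb read off the table above: beta ≈ 1.8 / radius.
    return 1.8 / radius


def radius_for_beta(beta):
    # Inverse of the same assumed rule: radius ≈ 1.8 / beta.
    return 1.8 / beta


def kernel_value(dist, beta, p=4096):
    # Normalized kernel K / p between two points `dist` apart along one axis,
    # using the p interior N(0, 1) quantiles that strict_standard_normal draws
    # (their shuffled order does not matter for this sum).
    zs = (NormalDist().inv_cdf(i / (p + 1)) for i in range(1, p + 1))
    return sum(math.cos(beta * z * dist) for z in zs) / p


def gaussian_kernel(dist, beta):
    # Closed form of the same average: E[cos(t Z)] = exp(-t^2 / 2) for Z ~ N(0, 1).
    return math.exp(-((beta * dist) ** 2) / 2)


print(round(beta_for_radius(0.3), 3))   # 6.0, matching the table
print(round(radius_for_beta(7), 3))     # 0.257, matching the table
for beta in (1.8, 6.0):
    r = radius_for_beta(beta)
    print(beta, kernel_value(r, beta), gaussian_kernel(r, beta))  # both close to exp(-1.62)
```

A few table entries (e.g. beta = 11 and 28) differ from 1.8 / beta in the third decimal, so treat the rule as approximate.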

Suggestion for picking $d$ and $p$

We find empirically that $d\times p$ is strongly correlated with the encoding quality. Here are several tips:

  • A larger local neighborhood requires a larger $d$.
  • A larger point cloud size requires a larger $p$.

Example:

  • d = 256, p = 4096 is for point cloud size ~100k. Runtime is about 80ms.
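The two configurations suggested in the VecKM docstring (d = 256, p = 4096 and d = 128, p = 8192) share the same product d × p = 2^20. A rough way to see what this product buys, read off the tensor shapes in the forward pass: both matrix products cost about n·2p·2d multiply-adds each, so compute grows with n·d·p, while the (n, 2p) tensor eB grows with p alone. The helper below is a hypothetical back-of-the-envelope estimate, not a benchmark; it ignores the cos/sin evaluations and the smaller (n, 2d) tensors:

```python
def veckm_budget(n, d, p, bytes_per_float=4):
    """Hypothetical estimate based on the shapes in VecKM.forward.

    Returns (d * p, multiply-adds of the two matmuls, bytes for eB plus the
    (2p, 2d) intermediate eB^T @ eA).
    """
    macs = 2 * n * (2 * p) * (2 * d)          # eB^T @ eA and eB @ (...), each n * 2p * 2d
    floats = n * (2 * p) + (2 * p) * (2 * d)  # eB is (n, 2p); the intermediate is (2p, 2d)
    return d * p, macs, floats * bytes_per_float


for n, d, p in [(20_000, 256, 4096), (50_000, 128, 8192)]:
    dp, macs, nbytes = veckm_budget(n, d, p)
    print(f"n={n}, d={d}, p={p}: d*p={dp}, MACs={macs:.2e}, ~{nbytes / 2**20:.0f} MiB")
```

At a fixed d × p, trading d for p leaves the matmul cost per point unchanged but enlarges eB, which is one reason to reduce d rather than keep growing p when memory is tight.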

Experiments

Check out the applications of VecKM to normal estimation, classification, and part segmentation. The overall architecture change looks like this:

Citation

If you find it helpful, please consider citing our paper:

@misc{yuan2024linear,
      title={A Linear Time and Space Local Point Cloud Geometry Encoder via Vectorized Kernel Mixture (VecKM)}, 
      author={Dehao Yuan and Cornelia Fermüller and Tahseen Rabbani and Furong Huang and Yiannis Aloimonos},
      year={2024},
      eprint={2404.01568},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}