PointNet Explained Visually
by Mariona Carós, PhD student at the University of Barcelona
PointNet is a deep net architecture that consumes point clouds directly, with applications ranging from object classification and part segmentation to semantic scene parsing. It was introduced in 2017 and was the first architecture to take raw point clouds as input for 3D recognition tasks.
The idea of this article is to implement the classification model of PointNet in PyTorch and visualize its transformations to understand how the model works.
In case you do not know what a point cloud is: it is just a 3D representation of an object or scene, typically collected with a LiDAR (Light Detection and Ranging) sensor. These sensors emit pulses of light and measure the time it takes for them to return to the sensor. This information can be used to create a 3D model of the object or scene like the one above. LiDAR sensors are becoming more and more popular: you can find them in autonomous vehicles, drones, mapping airplanes and even in some smartphones!
Contents
- Dataset: Transform 2D MNIST images into 3D point clouds
- Architecture & Properties of point sets
- Visualizing T-Net transformation
- Visualizing critical points
- Takeaways
Dataset
For the sake of simplicity, we will use the well-known MNIST dataset, which we can download directly with PyTorch (through torchvision).
MNIST contains 70,000 images of handwritten digits from 0 to 9 (60,000 in the default training partition and 10,000 in the test partition).
PointNet deals with points represented by three coordinates (x, y, z), so we are going to transform the 2D images into 3D point clouds like this one below.
MNIST samples are grayscale images of 28 x 28 pixels. Pixel values are integers that range from 0 (black) to 255 (white). We want to transform each of the digit’s pixels into a point. The function transform_img2pc filters the pixels of the image with a value higher than 127 and gets their indices.
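A minimal sketch of what transform_img2pc could look like, assuming the image arrives as a 28 x 28 NumPy array of uint8 values; the notebook’s actual implementation may differ. It returns one (row, col) coordinate pair per pixel brighter than 127:

```python
import numpy as np


def transform_img2pc(img, threshold=127):
    """Return the (row, col) indices of pixels brighter than `threshold` (sketch).

    img: 28x28 array of grayscale values (0 = black, 255 = white).
    """
    img = np.asarray(img)
    # np.argwhere returns one (row, col) pair per pixel that satisfies the condition,
    # in row-major order
    indices = np.argwhere(img > threshold)
    return indices.astype(np.float32)


# Toy usage: two bright pixels and one pixel exactly at the threshold
img = np.zeros((28, 28), dtype=np.uint8)
img[3, 4] = 200
img[10, 12] = 255
img[5, 5] = 127  # not strictly above 127, so it is dropped
print(transform_img2pc(img).tolist())  # [[3.0, 4.0], [10.0, 12.0]]
```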
Once we have converted pixels into points, we need all point clouds to have the same number of points so that we can feed them into PointNet. The authors of PointNet use 2500 points per object; here we plot the histogram of the number of points per digit to decide on a threshold.
We fix the number of points at 200: the maximum number of points is 312, but most point clouds have fewer than 200 points. This leaves two cases, point clouds above 200 points and point clouds below this threshold:
- When the number of points is above 200, we randomly sample 200 of them.
- When it is below 200, we randomly duplicate existing points until we reach 200.
Finally, we add a third coordinate z to all the points by sampling Gaussian noise with zero mean and a standard deviation of 0.05.
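The sampling, duplication and noise steps can be sketched as follows. The helper names resample_points and add_z are hypothetical (they are not taken from the notebook), and the sketch assumes the 2D points come from a function like transform_img2pc:

```python
import numpy as np

NUM_POINTS = 200  # fixed size chosen from the histogram of points per digit


def resample_points(points, num_points=NUM_POINTS, rng=None):
    """Randomly sample (too many points) or duplicate (too few points) 2D points."""
    rng = np.random.default_rng() if rng is None else rng
    n = len(points)
    if n >= num_points:
        # Too many points: keep a random subset, without replacement
        idx = rng.choice(n, num_points, replace=False)
    else:
        # Too few points: keep all of them and duplicate randomly chosen existing ones
        extra = rng.choice(n, num_points - n, replace=True)
        idx = np.concatenate([np.arange(n), extra])
    return points[idx]


def add_z(points, std=0.05, rng=None):
    """Append a z coordinate drawn from N(0, std^2) to each (x, y) point."""
    rng = np.random.default_rng() if rng is None else rng
    z = rng.normal(loc=0.0, scale=std, size=(len(points), 1))
    return np.hstack([points, z]).astype(np.float32)


# Toy usage: one cloud with too many points, one with too few
rng = np.random.default_rng(0)
many = np.arange(300 * 2, dtype=np.float32).reshape(300, 2)
few = np.arange(50 * 2, dtype=np.float32).reshape(50, 2)
cloud_a = add_z(resample_points(many, rng=rng), rng=rng)
cloud_b = add_z(resample_points(few, rng=rng), rng=rng)
print(cloud_a.shape, cloud_b.shape)  # (200, 3) (200, 3)
```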
Let’s wrap up the data processing in a custom Dataset class. The Dataset stores the preprocessed samples and their corresponding labels; we then define a DataLoader to iterate through the data in the training loop.
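A possible shape for the MNIST3D Dataset, as a sketch: the preprocess argument is a hypothetical callable (for example, a composition of the steps described above) that turns one image into a (200, 3) point cloud. Preprocessing runs once in __init__, so __getitem__ only indexes stored tensors:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class MNIST3D(Dataset):
    """Stores preprocessed point clouds and their labels (illustrative sketch)."""

    def __init__(self, images, labels, preprocess):
        # Convert every image to a point cloud once, up front, and keep the results
        clouds = [preprocess(img) for img in images]
        self.point_clouds = torch.as_tensor(np.stack(clouds), dtype=torch.float32)
        self.labels = torch.as_tensor(np.asarray(labels), dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Returns a (num_points, 3) tensor and a scalar label tensor
        return self.point_clouds[idx], self.labels[idx]


# Toy usage with blank images and a dummy preprocessing function
images = [np.zeros((28, 28), dtype=np.uint8) for _ in range(10)]
dataset = MNIST3D(images, list(range(10)), lambda img: np.zeros((200, 3), dtype=np.float32))
loader = DataLoader(dataset, batch_size=4, shuffle=False)
points, labels = next(iter(loader))
print(points.shape, labels.tolist())  # torch.Size([4, 200, 3]) [0, 1, 2, 3]
```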
Once the MNIST data is downloaded, we concatenate the default partitions (train and test) and feed the data into our custom MNIST3D Dataset. Then, we split our dataset into train (80%), validation (10%) and test (10%) sets and create a DataLoader for each partition with a batch size of 128.
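The 80/10/10 split can be done with torch.utils.data.random_split. In this sketch, make_loaders and the fixed seed are assumptions; the test partition takes the remainder so the sizes always add up (for the 70,000 MNIST samples this gives 56,000 / 7,000 / 7,000):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def make_loaders(dataset, batch_size=128, seed=42):
    """Split a dataset 80/10/10 and wrap each partition in a DataLoader (sketch)."""
    n = len(dataset)
    n_train = int(0.8 * n)
    n_val = int(0.1 * n)
    n_test = n - n_train - n_val  # remainder, so the three sizes always sum to n
    generator = torch.Generator().manual_seed(seed)  # reproducible split
    train_set, val_set, test_set = random_split(
        dataset, [n_train, n_val, n_test], generator=generator
    )
    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(val_set, batch_size=batch_size, shuffle=False),
        DataLoader(test_set, batch_size=batch_size, shuffle=False),
    )


# Toy usage: 1,000 dummy point clouds of 200 points each
toy = TensorDataset(torch.zeros(1000, 200, 3), torch.zeros(1000, dtype=torch.long))
train_loader, val_loader, test_loader = make_loaders(toy)
print(len(train_loader.dataset), len(val_loader.dataset), len(test_loader.dataset))  # 800 100 100
```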
Finally, we plot some samples to check that our point clouds are generated properly. You can also generate a cool 3D gif like the one above with the implementation in our notebook.
Now that data is ready, we can focus on the model!
Architecture and Properties of Point Sets
PointNet is composed of a classification network and a segmentation network. The classification network takes n points (x, y, z) as input, applies input and feature transformations using T-Net, and then aggregates point features by max pooling. The output is a classification score for each of the k classes. The segmentation network is an extension of the classification network: it concatenates global and local features and outputs per-point scores.
The architecture of PointNet is inspired by the properties of point sets, which motivate some of its key design choices. Let’s check them!
1. UNORDERED. Unlike pixel arrays in images, a point cloud is a set of points without a specific order.
- Requirement: The model needs to be invariant to permutations of points.
- Solution: Use a max pooling layer as a symmetric function to aggregate information from all the points. Max pooling, like * and +, is a symmetric function because the order of the inputs does not alter the result.
2. INTERACTION AMONG POINTS. The points come from a space with a distance metric. This means that points are not isolated: neighboring points form meaningful subsets.
- Requirement: The model needs to be able to capture local structures from nearby points.
- Solution: Combine local and global features for segmentation.
3. INVARIANCE UNDER TRANSFORMATIONS. The learned representation of the point set should be invariant to certain transformations.
- Requirement: Rotating and translating points all together should not modify the global point cloud category nor the segmentation of the points.
- Solution: Use a spatial transformer network that attempts to transform the data into a canonical form before PointNet processes it. T-Net is the network used to align both input points and point features.
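A quick check of the symmetric-function argument, using random per-point features as stand-ins for what the shared MLP would produce: max pooling over the point dimension gives exactly the same result after a permutation, summing does too (up to floating-point rounding), while simply flattening the points does not:

```python
import torch

torch.manual_seed(0)
features = torch.randn(200, 64)  # per-point features: 200 points x 64 channels
perm = torch.randperm(200)       # a random reordering of the points

# Symmetric aggregation: max over the point dimension ignores point order
global_max = features.max(dim=0).values
global_max_perm = features[perm].max(dim=0).values
print(torch.equal(global_max, global_max_perm))  # True

# Sum is also symmetric; the order of additions only changes float rounding
global_sum = features.sum(dim=0)
global_sum_perm = features[perm].sum(dim=0)
print(torch.allclose(global_sum, global_sum_perm, atol=1e-4))  # True

# Order-dependent aggregation: flattening keeps the points in their given order
flat_equal = torch.equal(features.flatten(), features[perm].flatten())
print(flat_equal)  # False unless perm happens to be the identity
```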
You can see the use of T-Net (input transform and feature_transform), max pooling (MaxPool1d) and feature generation (local and global) in the code below. ClassificationPointNet returns the log probabilities per point cloud, the feature transform needed for the loss regularization, and two more elements (tnet_out, ix_maxpool) used for plotting.
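Since the notebook code is not reproduced here, the following is a simplified, hypothetical TinyPointNet (not the article’s ClassificationPointNet) that leaves out both T-Nets and the feature-transform output to stay short. It shows the shared per-point MLP built from kernel-size-1 Conv1d layers, MaxPool1d with return_indices=True to keep ix_maxpool, and log-probabilities as output. The layer sizes loosely follow the original PointNet paper and are assumptions for this article’s model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyPointNet(nn.Module):
    """Simplified PointNet classifier without T-Nets (illustrative sketch)."""

    def __init__(self, num_points=200, num_classes=10):
        super().__init__()
        # Shared MLP: kernel-size-1 convolutions apply the same weights to every point
        self.conv1 = nn.Conv1d(3, 64, kernel_size=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=1)
        self.conv3 = nn.Conv1d(128, 1024, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        # Max pooling over all points; return_indices=True keeps the winning point per channel
        self.max_pool = nn.MaxPool1d(kernel_size=num_points, return_indices=True)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        # x: (batch, num_points, 3) -> Conv1d expects (batch, channels, num_points)
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))  # local per-point features
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        global_feat, ix_maxpool = self.max_pool(x)  # both (batch, 1024, 1)
        x = F.relu(self.fc1(global_feat.squeeze(-1)))  # global feature -> classifier
        x = self.dropout(F.relu(self.fc2(x)))
        return F.log_softmax(self.fc3(x), dim=1), ix_maxpool.squeeze(-1)


model = TinyPointNet()
model.eval()  # BatchNorm uses running statistics and dropout is disabled
log_probs, ix_maxpool = model(torch.randn(4, 200, 3))
print(log_probs.shape, ix_maxpool.shape)  # torch.Size([4, 10]) torch.Size([4, 1024])
```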
In the next section, we will go into more detail about the implementation of T-Net, how it works and what benefits it provides.
The __init__ functions have been omitted for space, but you can check them in the notebook.
Training PointNet
We use a classic PyTorch training loop to train our model, with the learning rate set to 0.001 and the maximum number of epochs set to 80. PointNet contains several MLPs and, as a result, has a large number of trainable parameters (3,472,339). To reduce the training time, the lighter version of PointNet in the link above (implemented in Google Colab) decreases the number of neurons per layer, which results in 910,611 trainable parameters; you can play with it there.
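Parameter counts like these are typically obtained by summing numel() over the parameters that require gradients. A small sketch, with count_trainable_parameters as a hypothetical helper name:

```python
import torch.nn as nn


def count_trainable_parameters(model: nn.Module) -> int:
    # Sum the number of elements of every parameter that receives gradients
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


layer = nn.Conv1d(3, 64, kernel_size=1)  # one layer of a shared per-point MLP
print(count_trainable_parameters(layer))  # 256 = 3*64 weights + 64 biases
```

Frozen parameters (requires_grad=False) are excluded, which is why the filter matters when part of a model is not trained.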
The model is optimized with the negative log-likelihood (NLL) loss, the typical loss for classification problems with several classes, plus a regularization term that constrains the feature transformation matrix to be close to orthogonal and makes optimization more stable.
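The regularizer described in the PointNet paper penalizes ||I - AA^T||_F^2, where A is the predicted feature transformation matrix, and is added to the NLL loss with a small weight (0.001 in the paper). The sketch below assumes the notebook follows the same formulation; some implementations use the unsquared Frobenius norm instead:

```python
import torch
import torch.nn.functional as F


def pointnet_loss(log_probs, labels, feature_transform, reg_weight=0.001):
    """NLL loss plus an orthogonality regularizer on the feature transform (sketch).

    log_probs: (B, num_classes) log-probabilities
    labels: (B,) class indices
    feature_transform: (B, k, k) matrices predicted by the feature T-Net
    """
    nll = F.nll_loss(log_probs, labels)
    k = feature_transform.size(1)
    identity = torch.eye(k, device=feature_transform.device).unsqueeze(0)
    # ||I - A A^T||_F^2 is zero exactly when A is orthogonal
    diff = identity - torch.bmm(feature_transform, feature_transform.transpose(1, 2))
    reg = diff.pow(2).sum(dim=(1, 2)).mean()  # squared Frobenius norm, averaged over batch
    return nll + reg_weight * reg


# Toy usage: an orthogonal (identity) transform adds no penalty
log_probs = torch.log_softmax(torch.randn(4, 10), dim=1)
labels = torch.tensor([0, 1, 2, 3])
A = torch.eye(64).repeat(4, 1, 1)
loss = pointnet_loss(log_probs, labels, A)
```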
Once the loss has converged, that is, once the validation loss stops decreasing, we stop training and test our model.
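A compact version of such a loop, as a sketch: the optimizer (Adam) and the patience of 5 epochs are assumptions not stated in the article, while the learning rate of 0.001 and the 80-epoch cap come from the text. loss_fn receives whatever the model returns, so it can wrap an NLL-plus-regularization loss:

```python
import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def train(model, train_loader, val_loader, loss_fn, max_epochs=80, lr=0.001, patience=5):
    """Train with Adam; stop early once the validation loss stops decreasing (sketch)."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_val, best_state, bad_epochs = math.inf, None, 0
    for _ in range(max_epochs):
        model.train()
        for points, labels in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(points), labels)
            loss.backward()
            optimizer.step()

        # Average validation loss over all validation samples
        model.eval()
        total, count = 0.0, 0
        with torch.no_grad():
            for points, labels in val_loader:
                total += loss_fn(model(points), labels).item() * len(labels)
                count += len(labels)
        val_loss = total / count

        if val_loss < best_val:
            best_val, bad_epochs = val_loss, 0
            best_state = copy.deepcopy(model.state_dict())  # remember the best weights
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break  # no improvement for `patience` consecutive epochs
    if best_state is not None:
        model.load_state_dict(best_state)
    return best_val


# Toy usage with random data and a tiny model that outputs log-probabilities
torch.manual_seed(0)
toy = TensorDataset(torch.randn(64, 200, 3), torch.randint(0, 10, (64,)))
loader = DataLoader(toy, batch_size=16)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(600, 10), nn.LogSoftmax(dim=1))
best = train(toy_model, loader, loader, F.nll_loss, max_epochs=3)
```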
Test accuracy: 0.967
Alert⚠️ If the model is not completely trained, it may not guarantee the property of invariance to permutations.
Visualizing input and output of T-Net
T-Net aligns the whole input set to a canonical space before feature extraction. How does it do it? It predicts a 3x3 affine transformation matrix that is applied to the coordinates of the input points (x, y, z).
This idea can be further extended to the alignment of feature space. You can see in the PointNet architecture figure that the second T-Net predicts a feature transformation matrix of 64x64, which is used to align features from different input point clouds.
As you can see in the following block of code, T-Net is composed of 1D convolutional layers for point-independent feature extraction, max pooling and fully connected layers. The result is a transformation matrix that we directly apply to the coordinates of the input points.
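A sketch of a T-Net along those lines. BatchNorm layers are omitted for brevity, and the identity initialization of the last layer follows the PointNet paper (the predicted matrix starts as the identity); the notebook may do this differently. The same class serves for the 3x3 input transform (k=3) and the 64x64 feature transform (k=64):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TNet(nn.Module):
    """Predicts a k x k transformation matrix from a set of k-dimensional points (sketch)."""

    def __init__(self, k=3):
        super().__init__()
        self.k = k
        # Point-independent feature extraction with kernel-size-1 convolutions
        self.conv1 = nn.Conv1d(k, 64, kernel_size=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=1)
        self.conv3 = nn.Conv1d(128, 1024, kernel_size=1)
        # Fully connected layers map the global feature to the k*k matrix entries
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        # Start from the identity transform: zero weights and an identity bias
        nn.init.zeros_(self.fc3.weight)
        with torch.no_grad():
            self.fc3.bias.copy_(torch.eye(k).flatten())

    def forward(self, x):
        # x: (batch, k, num_points)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.max(dim=2).values  # symmetric max pooling over the points
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x).view(-1, self.k, self.k)


# Toy usage: predict a 3x3 matrix and apply it to the input coordinates
torch.manual_seed(0)
points = torch.randn(2, 200, 3)          # (batch, num_points, 3)
tnet = TNet(k=3)
matrix = tnet(points.transpose(2, 1))    # (batch, 3, 3)
aligned = torch.bmm(points, matrix)      # (batch, num_points, 3)
```

Because the untrained T-Net outputs the identity, aligned equals points at initialization; training then learns how far to move away from it.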
Note📝 T-Net aligns the whole input set to a canonical space by learning a transformation matrix.
By plotting the product of the T-Net output and the input points, we can see the canonical transformation that T-Net applies to the input point cloud.
One of the properties of PointNet is that it is invariant to permutations of points. Let’s test it! We are going to shuffle points and compare both transformations and predictions. We will make the point size smaller to better identify the differences between both transformations.
We can see that for this example, with a different order of points we get a very similar representation and same prediction.
Will this hold for all test samples? Let’s compare the predictions for shuffled and non-shuffled points on all test samples.
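One way to run this comparison, sketched under the assumption that model(points) returns either log-probabilities or a tuple whose first element is the log-probabilities; predict and compare_shuffled_predictions are hypothetical helper names. Every batch is predicted twice, once with a random reordering of its points:

```python
import torch


def predict(model, points):
    out = model(points)
    # Assumption: a tuple output carries the log-probabilities in its first element
    log_probs = out[0] if isinstance(out, tuple) else out
    return log_probs.argmax(dim=1)


def compare_shuffled_predictions(model, loader):
    """Predict each batch with the original and with a shuffled point order (sketch)."""
    model.eval()
    results, results_shuffle = [], []
    with torch.no_grad():
        for points, _ in loader:
            perm = torch.randperm(points.shape[1])  # one random point order per batch
            results.append(predict(model, points))
            results_shuffle.append(predict(model, points[:, perm]))
    results, results_shuffle = torch.cat(results), torch.cat(results_shuffle)
    # Positions (in loader order) where shuffling changed the predicted digit
    mismatch_idx = torch.nonzero(results != results_shuffle).flatten()
    return results, results_shuffle, mismatch_idx
```

With a trained model and the test DataLoader, bool((results == results_shuffle).all()) gives the overall check, and mismatch_idx holds the samples to inspect.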
Comparing results with results_shuffle over the whole test set shows that they are not all equal (the check returns False).
We get 6 samples out of 7,000 (the test set size) with different results when shuffling points. We store the indices of those samples to compare both transformations and predictions. Here you can see a couple of examples:
The transformations are quite similar, and we could also be wrong if we guessed these digits by looking at the T-Net transformation. Why do you think the same model predicts different digits? We can plot the points that contributed to the max pooling to get an idea.
Visualizing PointNet Critical points
PointNet learns to summarize an input point cloud by a sparse set of key points, which the authors call critical points. The critical points are those that contribute to the max pooled feature.
We stored the indices returned by the max pooling layer, so we can plot those points for the shuffled and non-shuffled point clouds and obtain the following figures:
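Extracting the critical points from the stored indices takes one call to torch.unique, since several feature channels can select the same point. In the toy example below, critical_points is a hypothetical helper, and the coordinates themselves play the role of a 3-channel feature map:

```python
import torch


def critical_points(points, ix_maxpool):
    """Select the points that win at least one channel of the max pooling (sketch).

    points: (num_points, 3) one point cloud
    ix_maxpool: (num_features,) index of the winning point for each feature channel
    """
    critical_idx = torch.unique(ix_maxpool)  # sorted, duplicates removed
    return points[critical_idx], critical_idx


points = torch.tensor([
    [0.0, 0.0, 0.1],
    [1.0, 2.0, 0.2],
    [3.0, 1.0, 0.0],
    [0.5, 0.5, 0.3],
])
features = points.T                      # (channels, num_points) toy feature map
_, ix_maxpool = features.max(dim=1)      # winning point per channel: x->2, y->1, z->3
crit, idx = critical_points(points, ix_maxpool)
print(idx)  # tensor([1, 2, 3])
```

Point 0 never wins a channel, so it does not influence the global feature at all; that is the sense in which the critical set summarizes the cloud.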
We see that the critical point set corresponds to the skeleton of the digit and that it differs between the shuffled and non-shuffled point clouds. This difference causes the model to predict one class or the other!
Note📝 One way to improve invariance to permutations could be to train the model with shuffled points.
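A possible way to try this, sketched as a hypothetical Dataset wrapper (ShuffledPoints is not part of the notebook) that draws a fresh random point order every time a sample is accessed, so each epoch sees differently ordered point clouds:

```python
import torch
from torch.utils.data import Dataset, TensorDataset


class ShuffledPoints(Dataset):
    """Wraps a dataset and randomly reorders the points of every returned sample."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        points, label = self.base[idx]
        perm = torch.randperm(points.shape[0])  # new random order on every access
        return points[perm], label


# Toy usage: one cloud whose points have distinct, increasing coordinates
base = TensorDataset(torch.arange(600.0).view(1, 200, 3), torch.tensor([7]))
augmented = ShuffledPoints(base)
shuffled_points, label = augmented[0]
```

Only the training split would be wrapped this way; validation and test data keep their original order.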
Takeaways
We have played with PointNet (notebook here), a neural network that directly consumes point clouds, a natural representation of data when dealing with 3D objects and scenes. We have seen that:
- PointNet requires a fixed number of points per point cloud, which we obtained by randomly sampling points or duplicating existing ones.
- T-Net aligns both input points and point features by predicting a transformation matrix directly applied to the data.
- The max pooling layer, used as a symmetric function to aggregate information from all the points, can also be used to explain the model’s results.
If you’d like to see another applied example of PointNet with PyTorch, I recommend this article.
Thank you for following the tutorial! The goal of this work was to help you better understand PointNet in a friendly way. Even though there is room for improvement in the results, I believe it has been useful. If you liked it, please clap; any comment is welcome!