diff --git a/configs/_base_/datasets/semantickitti.py b/configs/_base_/datasets/semantickitti.py
index 9ad3e1e78a..989e267e18 100644
--- a/configs/_base_/datasets/semantickitti.py
+++ b/configs/_base_/datasets/semantickitti.py
@@ -1,77 +1,53 @@
-# dataset settings
-dataset_type = 'SemanticKITTIDataset'
+# For SemanticKitti we usually do 19-class segmentation.
+# For labels_map we follow the uniform format of MMDetection & MMSegmentation
+# i.e. we consider the unlabeled class as the last one, which is different
+# from the original implementation of some methods e.g. Cylinder3D.
+dataset_type = 'SemanticKittiDataset'
 data_root = 'data/semantickitti/'
 class_names = [
-    'unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus', 'person',
-    'bicyclist', 'motorcyclist', 'road', 'parking', 'sidewalk', 'other-ground',
-    'building', 'fence', 'vegetation', 'trunck', 'terrian', 'pole',
-    'traffic-sign'
+    'car', 'bicycle', 'motorcycle', 'truck', 'bus', 'person', 'bicyclist',
+    'motorcyclist', 'road', 'parking', 'sidewalk', 'other-ground', 'building',
+    'fence', 'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign'
 ]
-palette = [
-    [174, 199, 232],
-    [152, 223, 138],
-    [31, 119, 180],
-    [255, 187, 120],
-    [188, 189, 34],
-    [140, 86, 75],
-    [255, 152, 150],
-    [214, 39, 40],
-    [197, 176, 213],
-    [148, 103, 189],
-    [196, 156, 148],
-    [23, 190, 207],
-    [247, 182, 210],
-    [219, 219, 141],
-    [255, 127, 14],
-    [158, 218, 229],
-    [44, 160, 44],
-    [112, 128, 144],
-    [227, 119, 194],
-    [82, 84, 163],
-]
-
 labels_map = {
-    0: 0,  # "unlabeled"
-    1: 0,  # "outlier" mapped to "unlabeled" --------------mapped
-    10: 1,  # "car"
-    11: 2,  # "bicycle"
-    13: 5,  # "bus" mapped to "other-vehicle" --------------mapped
-    15: 3,  # "motorcycle"
-    16: 5,  # "on-rails" mapped to "other-vehicle" ---------mapped
-    18: 4,  # "truck"
-    20: 5,  # "other-vehicle"
-    30: 6,  # "person"
-    31: 7,  # "bicyclist"
-    32: 8,  # "motorcyclist"
-    40: 9,  # "road"
-    44: 10,  # "parking"
-    48: 11,  # "sidewalk"
-    49: 12,  # "other-ground"
-    50: 13,  # "building"
-    51: 14,  # "fence"
-    52: 0,  # "other-structure" mapped to "unlabeled" ------mapped
-    60: 9,  # "lane-marking" to "road" ---------------------mapped
-    70: 15,  # "vegetation"
-    71: 16,  # "trunk"
-    72: 17,  # "terrain"
-    80: 18,  # "pole"
-    81: 19,  # "traffic-sign"
-    99: 0,  # "other-object" to "unlabeled" ----------------mapped
-    252: 1,  # "moving-car" to "car" ------------------------mapped
-    253: 7,  # "moving-bicyclist" to "bicyclist" ------------mapped
-    254: 6,  # "moving-person" to "person" ------------------mapped
-    255: 8,  # "moving-motorcyclist" to "motorcyclist" ------mapped
-    256: 5,  # "moving-on-rails" mapped to "other-vehic------mapped
-    257: 5,  # "moving-bus" mapped to "other-vehicle" -------mapped
-    258: 4,  # "moving-truck" to "truck" --------------------mapped
-    259: 5  # "moving-other"-vehicle to "other-vehicle"-----mapped
+    0: 19,  # "unlabeled"
+    1: 19,  # "outlier" mapped to "unlabeled" --------------mapped
+    10: 0,  # "car"
+    11: 1,  # "bicycle"
+    13: 4,  # "bus" mapped to "other-vehicle" --------------mapped
+    15: 2,  # "motorcycle"
+    16: 4,  # "on-rails" mapped to "other-vehicle" ---------mapped
+    18: 3,  # "truck"
+    20: 4,  # "other-vehicle"
+    30: 5,  # "person"
+    31: 6,  # "bicyclist"
+    32: 7,  # "motorcyclist"
+    40: 8,  # "road"
+    44: 9,  # "parking"
+    48: 10,  # "sidewalk"
+    49: 11,  # "other-ground"
+    50: 12,  # "building"
+    51: 13,  # "fence"
+    52: 19,  # "other-structure" mapped to "unlabeled" ------mapped
+    60: 8,  # "lane-marking" to "road" ---------------------mapped
+    70: 14,  # "vegetation"
+    71: 15,  # "trunk"
+    72: 16,  # "terrain"
+    80: 17,  # "pole"
+    81: 18,  # "traffic-sign"
+    99: 19,  # "other-object" to "unlabeled" ----------------mapped
+    252: 0,  # "moving-car" to "car" ------------------------mapped
+    253: 6,  # "moving-bicyclist" to "bicyclist" ------------mapped
+    254: 5,  # "moving-person" to "person" ------------------mapped
+    255: 7,  # "moving-motorcyclist" to "motorcyclist" ------mapped
+    256: 4,  # "moving-on-rails" mapped to "other-vehic------mapped
+    257: 4,  # "moving-bus" mapped to "other-vehicle" -------mapped
+    258: 3,  # "moving-truck" to "truck" --------------------mapped
+    259: 4  # "moving-other"-vehicle to "other-vehicle"-----mapped
 }
 
 metainfo = dict(
-    classes=class_names,
-    palette=palette,
-    seg_label_mapping=labels_map,
-    max_label=259)
+    classes=class_names, seg_label_mapping=labels_map, max_label=259)
 
 input_modality = dict(use_lidar=True, use_camera=False)
 
@@ -99,7 +75,10 @@
         backend_args=backend_args),
     dict(
         type='LoadAnnotations3D',
+        with_bbox_3d=False,
+        with_label_3d=False,
         with_seg_3d=True,
+        seg_3d_dtype='np.int32',
         seg_offset=2**16,
         dataset_type='semantickitti',
         backend_args=backend_args),
@@ -126,7 +105,10 @@
         backend_args=backend_args),
     dict(
         type='LoadAnnotations3D',
+        with_bbox_3d=False,
+        with_label_3d=False,
         with_seg_3d=True,
+        seg_3d_dtype='np.int32',
         seg_offset=2**16,
         dataset_type='semantickitti',
         backend_args=backend_args),
@@ -144,7 +126,10 @@
         backend_args=backend_args),
     dict(
         type='LoadAnnotations3D',
+        with_bbox_3d=False,
+        with_label_3d=False,
         with_seg_3d=True,
+        seg_3d_dtype='np.int32',
         seg_offset=2**16,
         dataset_type='semantickitti',
         backend_args=backend_args),
@@ -153,7 +138,7 @@
 ]
 
 train_dataloader = dict(
-    batch_size=4,
+    batch_size=2,
     num_workers=4,
     sampler=dict(type='DefaultSampler', shuffle=True),
     dataset=dict(
@@ -162,10 +147,11 @@
         dataset=dict(
             type=dataset_type,
             data_root=data_root,
-            ann_file='train_infos.pkl',
+            ann_file='semantickitti_infos_train.pkl',
             pipeline=train_pipeline,
             metainfo=metainfo,
             modality=input_modality,
+            ignore_index=19,
             backend_args=backend_args)),
 )
 
@@ -179,10 +165,11 @@
         dataset=dict(
             type=dataset_type,
             data_root=data_root,
-            ann_file='valid_infos.pkl',
+            ann_file='semantickitti_infos_val.pkl',
             pipeline=test_pipeline,
             metainfo=metainfo,
             modality=input_modality,
+            ignore_index=19,
             test_mode=True,
             backend_args=backend_args)),
 )
@@ -191,3 +178,7 @@
 
 val_evaluator = dict(type='SegMetric')
 test_evaluator = val_evaluator
+
+vis_backends = [dict(type='LocalVisBackend')]
+visualizer = dict(
+    type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
diff --git a/mmdet3d/datasets/__init__.py b/mmdet3d/datasets/__init__.py
index e1153ba891..d573ca4ed9 100644
--- a/mmdet3d/datasets/__init__.py
+++ b/mmdet3d/datasets/__init__.py
@@ -9,7 +9,7 @@
 from .scannet_dataset import (ScanNetDataset, ScanNetInstanceSegDataset,
                               ScanNetSegDataset)
 from .seg3d_dataset import Seg3DDataset
-from .semantickitti_dataset import SemanticKITTIDataset
+from .semantickitti_dataset import SemanticKittiDataset
 from .sunrgbd_dataset import SUNRGBDDataset
 # yapf: disable
 from .transforms import (AffineResize, BackgroundPointsFilter, GlobalAlignment,
@@ -33,7 +33,7 @@
     'NormalizePointsColor', 'IndoorPatchPointSample', 'IndoorPointSample',
     'PointSample', 'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset',
     'ScanNetDataset', 'ScanNetSegDataset', 'ScanNetInstanceSegDataset',
-    'SemanticKITTIDataset', 'Det3DDataset', 'Seg3DDataset',
+    'SemanticKittiDataset', 'Det3DDataset', 'Seg3DDataset',
     'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter',
     'VoxelBasedPointSampler', 'get_loading_pipeline', 'RandomDropPointsColor',
     'RandomJitterPoints', 'ObjectNameFilter', 'AffineResize',
diff --git a/mmdet3d/datasets/seg3d_dataset.py b/mmdet3d/datasets/seg3d_dataset.py
index 3f30fb6ccb..803a1b4d2f 100644
--- a/mmdet3d/datasets/seg3d_dataset.py
+++ b/mmdet3d/datasets/seg3d_dataset.py
@@ -256,6 +256,9 @@ def parse_data_info(self, info: dict) -> dict:
                     self.data_prefix.get('pts', ''),
                     info['lidar_points']['lidar_path'])
 
+            info['num_pts_feats'] = info['lidar_points']['num_pts_feats']
+            info['lidar_path'] = info['lidar_points']['lidar_path']
+
         if self.modality['use_camera']:
             for cam_id, img_info in info['images'].items():
                 if 'img_path' in img_info:
diff --git a/mmdet3d/datasets/semantickitti_dataset.py b/mmdet3d/datasets/semantickitti_dataset.py
index c157208c6c..134333f778 100644
--- a/mmdet3d/datasets/semantickitti_dataset.py
+++ b/mmdet3d/datasets/semantickitti_dataset.py
@@ -8,8 +8,8 @@
 
 
 @DATASETS.register_module()
-class SemanticKITTIDataset(Seg3DDataset):
-    r"""SemanticKITTI Dataset.
+class SemanticKittiDataset(Seg3DDataset):
+    r"""SemanticKitti Dataset.
 
     This class serves as the API for experiments on the SemanticKITTI Dataset
     Please refer to <http://www.semantic-kitti.org/dataset.html>`_
@@ -44,14 +44,20 @@ class SemanticKITTIDataset(Seg3DDataset):
             Defaults to False.
     """
     METAINFO = {
-        'classes': ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck',
-                    'bus', 'person', 'bicyclist', 'motorcyclist', 'road',
-                    'parking', 'sidewalk', 'other-ground', 'building', 'fence',
-                    'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign'),
+        'classes': ('car', 'bicycle', 'motorcycle', 'truck', 'bus', 'person',
+                    'bicyclist', 'motorcyclist', 'road', 'parking', 'sidewalk',
+                    'other-ground', 'building', 'fence', 'vegetation',
+                    'trunck', 'terrian', 'pole', 'traffic-sign'),
+        'palette': [[100, 150, 245], [100, 230, 245], [30, 60, 150],
+                    [80, 30, 180], [100, 80, 250], [155, 30, 30],
+                    [255, 40, 200], [150, 30, 90], [255, 0, 255],
+                    [255, 150, 255], [75, 0, 75], [175, 0, 75], [255, 200, 0],
+                    [255, 120, 50], [0, 175, 0], [135, 60, 0], [150, 240, 80],
+                    [255, 240, 150], [255, 0, 0]],
         'seg_valid_class_ids':
-        tuple(range(20)),
+        tuple(range(19)),
         'seg_all_class_ids':
-        tuple(range(20)),
+        tuple(range(19)),
     }
 
     def __init__(self,
@@ -59,7 +65,7 @@ def __init__(self,
                  ann_file: str = '',
                  metainfo: Optional[dict] = None,
                  data_prefix: dict = dict(
-                     pts='points',
+                     pts='',
                      img='',
                      pts_instance_mask='',
                      pts_semantic_mask=''),
@@ -83,7 +89,7 @@ def __init__(self,
             **kwargs)
 
     def get_seg_label_mapping(self, metainfo):
-        seg_label_mapping = np.zeros(metainfo['max_label'] + 1)
+        seg_label_mapping = np.zeros(metainfo['max_label'] + 1, dtype=np.int64)
         for idx in metainfo['seg_label_mapping']:
             seg_label_mapping[idx] = metainfo['seg_label_mapping'][idx]
         return seg_label_mapping
diff --git a/tests/test_datasets/test_semantickitti_dataset.py b/tests/test_datasets/test_semantickitti_dataset.py
index 300253a681..d334870da5 100644
--- a/tests/test_datasets/test_semantickitti_dataset.py
+++ b/tests/test_datasets/test_semantickitti_dataset.py
@@ -3,75 +3,53 @@
 
 import numpy as np
 
-from mmdet3d.datasets import SemanticKITTIDataset
+from mmdet3d.datasets import SemanticKittiDataset
 from mmdet3d.utils import register_all_modules
 
 
 def _generate_semantickitti_dataset_config():
     data_root = './tests/data/semantickitti/'
     ann_file = 'semantickitti_infos.pkl'
-    classes = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus',
-               'person', 'bicyclist', 'motorcyclist', 'road', 'parking',
-               'sidewalk', 'other-ground', 'building', 'fence', 'vegetation',
-               'trunck', 'terrian', 'pole', 'traffic-sign')
-    palette = [
-        [174, 199, 232],
-        [152, 223, 138],
-        [31, 119, 180],
-        [255, 187, 120],
-        [188, 189, 34],
-        [140, 86, 75],
-        [255, 152, 150],
-        [214, 39, 40],
-        [197, 176, 213],
-        [148, 103, 189],
-        [196, 156, 148],
-        [23, 190, 207],
-        [247, 182, 210],
-        [219, 219, 141],
-        [255, 127, 14],
-        [158, 218, 229],
-        [44, 160, 44],
-        [112, 128, 144],
-        [227, 119, 194],
-        [82, 84, 163],
-    ]
+    classes = ('car', 'bicycle', 'motorcycle', 'truck', 'bus', 'person',
+               'bicyclist', 'motorcyclist', 'road', 'parking', 'sidewalk',
+               'other-ground', 'building', 'fence', 'vegetation', 'trunck',
+               'terrian', 'pole', 'traffic-sign')
 
     seg_label_mapping = {
-        0: 0,  # "unlabeled"
-        1: 0,  # "outlier" mapped to "unlabeled" --------------mapped
-        10: 1,  # "car"
-        11: 2,  # "bicycle"
-        13: 5,  # "bus" mapped to "other-vehicle" --------------mapped
-        15: 3,  # "motorcycle"
-        16: 5,  # "on-rails" mapped to "other-vehicle" ---------mapped
-        18: 4,  # "truck"
-        20: 5,  # "other-vehicle"
-        30: 6,  # "person"
-        31: 7,  # "bicyclist"
-        32: 8,  # "motorcyclist"
-        40: 9,  # "road"
-        44: 10,  # "parking"
-        48: 11,  # "sidewalk"
-        49: 12,  # "other-ground"
-        50: 13,  # "building"
-        51: 14,  # "fence"
-        52: 0,  # "other-structure" mapped to "unlabeled" ------mapped
-        60: 9,  # "lane-marking" to "road" ---------------------mapped
-        70: 15,  # "vegetation"
-        71: 16,  # "trunk"
-        72: 17,  # "terrain"
-        80: 18,  # "pole"
-        81: 19,  # "traffic-sign"
-        99: 0,  # "other-object" to "unlabeled" ----------------mapped
-        252: 1,  # "moving-car" to "car" ------------------------mapped
-        253: 7,  # "moving-bicyclist" to "bicyclist" ------------mapped
-        254: 6,  # "moving-person" to "person" ------------------mapped
-        255: 8,  # "moving-motorcyclist" to "motorcyclist" ------mapped
-        256: 5,  # "moving-on-rails" mapped to "other-vehic------mapped
-        257: 5,  # "moving-bus" mapped to "other-vehicle" -------mapped
-        258: 4,  # "moving-truck" to "truck" --------------------mapped
-        259: 5  # "moving-other"-vehicle to "other-vehicle"-----mapped
+        0: 19,  # "unlabeled"
+        1: 19,  # "outlier" mapped to "unlabeled" --------------mapped
+        10: 0,  # "car"
+        11: 1,  # "bicycle"
+        13: 4,  # "bus" mapped to "other-vehicle" --------------mapped
+        15: 2,  # "motorcycle"
+        16: 4,  # "on-rails" mapped to "other-vehicle" ---------mapped
+        18: 3,  # "truck"
+        20: 4,  # "other-vehicle"
+        30: 5,  # "person"
+        31: 6,  # "bicyclist"
+        32: 7,  # "motorcyclist"
+        40: 8,  # "road"
+        44: 9,  # "parking"
+        48: 10,  # "sidewalk"
+        49: 11,  # "other-ground"
+        50: 12,  # "building"
+        51: 13,  # "fence"
+        52: 19,  # "other-structure" mapped to "unlabeled" ------mapped
+        60: 8,  # "lane-marking" to "road" ---------------------mapped
+        70: 14,  # "vegetation"
+        71: 15,  # "trunk"
+        72: 16,  # "terrain"
+        80: 17,  # "pole"
+        81: 18,  # "traffic-sign"
+        99: 19,  # "other-object" to "unlabeled" ----------------mapped
+        252: 0,  # "moving-car" to "car" ------------------------mapped
+        253: 6,  # "moving-bicyclist" to "bicyclist" ------------mapped
+        254: 5,  # "moving-person" to "person" ------------------mapped
+        255: 7,  # "moving-motorcyclist" to "motorcyclist" ------mapped
+        256: 4,  # "moving-on-rails" mapped to "other-vehic------mapped
+        257: 4,  # "moving-bus" mapped to "other-vehicle" -------mapped
+        258: 3,  # "moving-truck" to "truck" --------------------mapped
+        259: 4  # "moving-other"-vehicle to "other-vehicle"-----mapped
     }
     max_label = 259
     modality = dict(use_lidar=True, use_camera=False)
@@ -96,25 +74,24 @@ def _generate_semantickitti_dataset_config():
     data_prefix = dict(
         pts='sequences/00/velodyne', pts_semantic_mask='sequences/00/labels')
 
-    return (data_root, ann_file, classes, palette, data_prefix, pipeline,
-            modality, seg_label_mapping, max_label)
+    return (data_root, ann_file, classes, data_prefix, pipeline, modality,
+            seg_label_mapping, max_label)
 
 
-class TestSemanticKITTIDataset(unittest.TestCase):
+class TestSemanticKittiDataset(unittest.TestCase):
 
     def test_semantickitti(self):
-        (data_root, ann_file, classes, palette, data_prefix, pipeline,
-         modality, seg_label_mapping,
+        (data_root, ann_file, classes, data_prefix, pipeline, modality,
+         seg_label_mapping,
          max_label) = _generate_semantickitti_dataset_config()
 
         register_all_modules()
         np.random.seed(0)
-        semantickitti_dataset = SemanticKITTIDataset(
+        semantickitti_dataset = SemanticKittiDataset(
             data_root,
             ann_file,
             metainfo=dict(
                 classes=classes,
-                palette=palette,
                 seg_label_mapping=seg_label_mapping,
                 max_label=max_label),
             data_prefix=data_prefix,
@@ -129,10 +106,9 @@ def test_semantickitti(self):
         self.assertEqual(points.shape[0], pts_semantic_mask.shape[0])
 
         expected_pts_semantic_mask = np.array([
-            13., 13., 13., 15., 15., 13., 0., 13., 15., 13., 13., 15., 16., 0.,
-            15., 13., 13., 13., 13., 0., 13., 13., 13., 13., 13., 15., 13.,
-            16., 13., 15., 15., 18., 13., 15., 15., 15., 16., 15., 13., 13.,
-            15., 13., 18., 15., 13., 15., 13., 15., 15., 13.
+            12, 12, 12, 14, 14, 12, 19, 12, 14, 12, 12, 14, 15, 19, 14, 12, 12,
+            12, 12, 19, 12, 12, 12, 12, 12, 14, 12, 15, 12, 14, 14, 17, 12, 14,
+            14, 14, 15, 14, 12, 12, 14, 12, 17, 14, 12, 14, 12, 14, 14, 12
         ])
 
         self.assertTrue(
diff --git a/tests/test_datasets/test_transforms/test_transforms_3d.py b/tests/test_datasets/test_transforms/test_transforms_3d.py
index 3d6fd6eac2..94d2e0c55e 100644
--- a/tests/test_datasets/test_transforms/test_transforms_3d.py
+++ b/tests/test_datasets/test_transforms/test_transforms_3d.py
@@ -7,7 +7,7 @@
 from mmengine.testing import assert_allclose
 
 from mmdet3d.datasets import (GlobalAlignment, RandomFlip3D,
-                              SemanticKITTIDataset)
+                              SemanticKittiDataset)
 from mmdet3d.datasets.transforms import GlobalRotScaleTrans, LaserMix, PolarMix
 from mmdet3d.structures import LiDARPoints
 from mmdet3d.testing import create_data_info_after_loading
@@ -124,32 +124,10 @@ def setUp(self):
                 seg_3d_dtype='np.int32'),
             dict(type='PointSegClassMapping'),
         ]
-        classes = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus',
-                   'person', 'bicyclist', 'motorcyclist', 'road', 'parking',
-                   'sidewalk', 'other-ground', 'building', 'fence',
-                   'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign')
-        palette = [
-            [174, 199, 232],
-            [152, 223, 138],
-            [31, 119, 180],
-            [255, 187, 120],
-            [188, 189, 34],
-            [140, 86, 75],
-            [255, 152, 150],
-            [214, 39, 40],
-            [197, 176, 213],
-            [148, 103, 189],
-            [196, 156, 148],
-            [23, 190, 207],
-            [247, 182, 210],
-            [219, 219, 141],
-            [255, 127, 14],
-            [158, 218, 229],
-            [44, 160, 44],
-            [112, 128, 144],
-            [227, 119, 194],
-            [82, 84, 163],
-        ]
+        classes = ('car', 'bicycle', 'motorcycle', 'truck', 'bus', 'person',
+                   'bicyclist', 'motorcyclist', 'road', 'parking', 'sidewalk',
+                   'other-ground', 'building', 'fence', 'vegetation', 'trunck',
+                   'terrian', 'pole', 'traffic-sign')
         seg_label_mapping = {
             0: 0,  # "unlabeled"
             1: 0,  # "outlier" mapped to "unlabeled" --------------mapped
@@ -187,12 +165,11 @@ def setUp(self):
             259: 5  # "moving-other"-vehicle to "other-vehicle"-----mapped
         }
         max_label = 259
-        self.dataset = SemanticKITTIDataset(
+        self.dataset = SemanticKittiDataset(
             './tests/data/semantickitti/',
             'semantickitti_infos.pkl',
             metainfo=dict(
                 classes=classes,
-                palette=palette,
                 seg_label_mapping=seg_label_mapping,
                 max_label=max_label),
             data_prefix=dict(
@@ -242,32 +219,10 @@ def setUp(self):
                 seg_3d_dtype='np.int32'),
             dict(type='PointSegClassMapping'),
         ]
-        classes = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus',
-                   'person', 'bicyclist', 'motorcyclist', 'road', 'parking',
-                   'sidewalk', 'other-ground', 'building', 'fence',
-                   'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign')
-        palette = [
-            [174, 199, 232],
-            [152, 223, 138],
-            [31, 119, 180],
-            [255, 187, 120],
-            [188, 189, 34],
-            [140, 86, 75],
-            [255, 152, 150],
-            [214, 39, 40],
-            [197, 176, 213],
-            [148, 103, 189],
-            [196, 156, 148],
-            [23, 190, 207],
-            [247, 182, 210],
-            [219, 219, 141],
-            [255, 127, 14],
-            [158, 218, 229],
-            [44, 160, 44],
-            [112, 128, 144],
-            [227, 119, 194],
-            [82, 84, 163],
-        ]
+        classes = ('car', 'bicycle', 'motorcycle', 'truck', 'bus', 'person',
+                   'bicyclist', 'motorcyclist', 'road', 'parking', 'sidewalk',
+                   'other-ground', 'building', 'fence', 'vegetation', 'trunck',
+                   'terrian', 'pole', 'traffic-sign')
         seg_label_mapping = {
             0: 0,  # "unlabeled"
             1: 0,  # "outlier" mapped to "unlabeled" --------------mapped
@@ -305,12 +260,11 @@ def setUp(self):
             259: 5  # "moving-other"-vehicle to "other-vehicle"-----mapped
         }
         max_label = 259
-        self.dataset = SemanticKITTIDataset(
+        self.dataset = SemanticKittiDataset(
             './tests/data/semantickitti/',
             'semantickitti_infos.pkl',
             metainfo=dict(
                 classes=classes,
-                palette=palette,
                 seg_label_mapping=seg_label_mapping,
                 max_label=max_label),
             data_prefix=dict(
diff --git a/tools/dataset_converters/semantickitti_converter.py b/tools/dataset_converters/semantickitti_converter.py
index 4df419f6e4..2454eea6f9 100644
--- a/tools/dataset_converters/semantickitti_converter.py
+++ b/tools/dataset_converters/semantickitti_converter.py
@@ -62,7 +62,9 @@ def get_semantickitti_info(split):
                     'lidar_path':
                     osp.join('sequences',
                              str(i_folder).zfill(2), 'velodyne',
-                             str(j).zfill(6) + '.bin')
+                             str(j).zfill(6) + '.bin'),
+                    'num_pts_feats':
+                    4
                 },
                 'pts_semantic_mask_path':
                 osp.join('sequences',