-
Notifications
You must be signed in to change notification settings - Fork 114
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add vanilla train/val/test codes of SemanticKITTI/nuScenes/Waymo
- Loading branch information
Showing
36 changed files
with
4,732 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# -*- coding:utf-8 -*- | ||
# author: Xinge | ||
# @file: __init__.py.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# -*- coding:utf-8 -*- | ||
# author: Xinge | ||
# @file: data_builder.py | ||
|
||
import torch | ||
from dataloader.dataset_semantickitti import get_model_class, collate_fn_BEV, collate_fn_BEV_tta, collate_fn_BEV_ms, collate_fn_BEV_ms_tta | ||
from dataloader.pc_dataset import get_pc_model_class | ||
|
||
|
||
def build(dataset_config, | ||
train_dataloader_config, | ||
val_dataloader_config, | ||
grid_size=[480, 360, 32], | ||
use_tta=False, | ||
use_multiscan=False, | ||
use_waymo=False): | ||
data_path = train_dataloader_config["data_path"] | ||
train_imageset = train_dataloader_config["imageset"] | ||
val_imageset = val_dataloader_config["imageset"] | ||
train_ref = train_dataloader_config["return_ref"] | ||
val_ref = val_dataloader_config["return_ref"] | ||
|
||
label_mapping = dataset_config["label_mapping"] | ||
|
||
SemKITTI = get_pc_model_class(dataset_config['pc_dataset_type']) | ||
|
||
nusc=None | ||
if "nusc" in dataset_config['pc_dataset_type']: | ||
from nuscenes import NuScenes | ||
nusc = NuScenes(version='v1.0-trainval', dataroot=data_path, verbose=True) | ||
|
||
train_pt_dataset = SemKITTI(data_path, imageset=train_imageset, | ||
return_ref=train_ref, label_mapping=label_mapping, nusc=nusc) | ||
val_pt_dataset = SemKITTI(data_path, imageset=val_imageset, | ||
return_ref=val_ref, label_mapping=label_mapping, nusc=nusc) | ||
|
||
train_dataset = get_model_class(dataset_config['dataset_type'])( | ||
train_pt_dataset, | ||
grid_size=grid_size, | ||
flip_aug=True, | ||
fixed_volume_space=dataset_config['fixed_volume_space'], | ||
max_volume_space=dataset_config['max_volume_space'], | ||
min_volume_space=dataset_config['min_volume_space'], | ||
ignore_label=dataset_config["ignore_label"], | ||
rotate_aug=True, | ||
scale_aug=True, | ||
transform_aug=True | ||
) | ||
|
||
if use_tta: | ||
val_dataset = get_model_class(dataset_config['dataset_type'])( | ||
val_pt_dataset, | ||
grid_size=grid_size, | ||
flip_aug=True, | ||
fixed_volume_space=dataset_config['fixed_volume_space'], | ||
max_volume_space=dataset_config['max_volume_space'], | ||
min_volume_space=dataset_config['min_volume_space'], | ||
ignore_label=dataset_config["ignore_label"], | ||
rotate_aug=True, | ||
scale_aug=True, | ||
return_test=True, | ||
use_tta=True, | ||
) | ||
if use_multiscan: | ||
collate_fn_BEV_tmp = collate_fn_BEV_ms_tta | ||
else: | ||
collate_fn_BEV_tmp = collate_fn_BEV_tta | ||
else: | ||
val_dataset = get_model_class(dataset_config['dataset_type'])( | ||
val_pt_dataset, | ||
grid_size=grid_size, | ||
fixed_volume_space=dataset_config['fixed_volume_space'], | ||
max_volume_space=dataset_config['max_volume_space'], | ||
min_volume_space=dataset_config['min_volume_space'], | ||
ignore_label=dataset_config["ignore_label"], | ||
) | ||
if use_multiscan or use_waymo: | ||
collate_fn_BEV_tmp = collate_fn_BEV_ms | ||
else: | ||
collate_fn_BEV_tmp = collate_fn_BEV | ||
|
||
train_dataset_loader = torch.utils.data.DataLoader(dataset=train_dataset, | ||
batch_size=train_dataloader_config["batch_size"], | ||
collate_fn=collate_fn_BEV_tmp, | ||
shuffle=train_dataloader_config["shuffle"], | ||
num_workers=train_dataloader_config["num_workers"]) | ||
val_dataset_loader = torch.utils.data.DataLoader(dataset=val_dataset, | ||
batch_size=val_dataloader_config["batch_size"], | ||
collate_fn=collate_fn_BEV_tmp, | ||
shuffle=val_dataloader_config["shuffle"], | ||
num_workers=val_dataloader_config["num_workers"]) | ||
|
||
if use_tta: | ||
return train_dataset_loader, val_dataset_loader, val_pt_dataset | ||
else: | ||
return train_dataset_loader, val_dataset_loader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# -*- coding:utf-8 -*- | ||
# author: Xinge | ||
# @file: loss_builder.py | ||
|
||
import torch | ||
from utils.lovasz_losses import lovasz_softmax | ||
|
||
|
||
def build(wce=True, lovasz=True, num_class=20, ignore_label=0): | ||
|
||
loss_funs = torch.nn.CrossEntropyLoss(ignore_index=ignore_label) | ||
|
||
if wce and lovasz: | ||
return loss_funs, lovasz_softmax | ||
elif wce and not lovasz: | ||
return wce | ||
elif not wce and lovasz: | ||
return lovasz_softmax | ||
else: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# -*- coding:utf-8 -*- | ||
# author: Xinge | ||
# @file: model_builder.py | ||
|
||
from network.cylinder_spconv_3d import get_model_class | ||
from network.segmentator_3d_asymm_spconv import Asymm_3d_spconv | ||
from network.cylinder_fea_generator import cylinder_fea | ||
|
||
|
||
def build(model_config): | ||
output_shape = model_config['output_shape'] | ||
num_class = model_config['num_class'] | ||
num_input_features = model_config['num_input_features'] | ||
use_norm = model_config['use_norm'] | ||
init_size = model_config['init_size'] | ||
fea_dim = model_config['fea_dim'] | ||
out_fea_dim = model_config['out_fea_dim'] | ||
|
||
cylinder_3d_spconv_seg = Asymm_3d_spconv( | ||
output_shape=output_shape, | ||
use_norm=use_norm, | ||
num_input_features=num_input_features, | ||
init_size=init_size, | ||
nclasses=num_class) | ||
|
||
cy_fea_net = cylinder_fea(grid_size=output_shape, | ||
fea_dim=fea_dim, | ||
out_pt_fea_dim=out_fea_dim, | ||
fea_compre=num_input_features) | ||
|
||
model = get_model_class(model_config["model_architecture"])( | ||
cylin_model=cy_fea_net, | ||
segmentator_spconv=cylinder_3d_spconv_seg, | ||
sparse_shape=output_shape | ||
) | ||
|
||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# -*- coding:utf-8 -*- | ||
# author: Xinge | ||
# @file: __init__.py.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# -*- coding:utf-8 -*- | ||
# author: Xinge | ||
|
||
from pathlib import Path | ||
|
||
from strictyaml import Bool, Float, Int, Map, Seq, Str, as_document, load | ||
|
||
model_params = Map( | ||
{ | ||
"model_architecture": Str(), | ||
"output_shape": Seq(Int()), | ||
"fea_dim": Int(), | ||
"out_fea_dim": Int(), | ||
"num_class": Int(), | ||
"num_input_features": Int(), | ||
"use_norm": Bool(), | ||
"init_size": Int(), | ||
} | ||
) | ||
|
||
dataset_params = Map( | ||
{ | ||
"dataset_type": Str(), | ||
"pc_dataset_type": Str(), | ||
"ignore_label": Int(), | ||
"return_test": Bool(), | ||
"fixed_volume_space": Bool(), | ||
"label_mapping": Str(), | ||
"max_volume_space": Seq(Float()), | ||
"min_volume_space": Seq(Float()), | ||
} | ||
) | ||
|
||
|
||
train_data_loader = Map( | ||
{ | ||
"data_path": Str(), | ||
"imageset": Str(), | ||
"return_ref": Bool(), | ||
"batch_size": Int(), | ||
"shuffle": Bool(), | ||
"num_workers": Int(), | ||
} | ||
) | ||
|
||
val_data_loader = Map( | ||
{ | ||
"data_path": Str(), | ||
"imageset": Str(), | ||
"return_ref": Bool(), | ||
"batch_size": Int(), | ||
"shuffle": Bool(), | ||
"num_workers": Int(), | ||
} | ||
) | ||
|
||
|
||
train_params = Map( | ||
{ | ||
"model_load_path": Str(), | ||
"model_save_path": Str(), | ||
"checkpoint_every_n_steps": Int(), | ||
"max_num_epochs": Int(), | ||
"eval_every_n_steps": Int(), | ||
"learning_rate": Float() | ||
} | ||
) | ||
|
||
schema_v4 = Map( | ||
{ | ||
"format_version": Int(), | ||
"model_params": model_params, | ||
"dataset_params": dataset_params, | ||
"train_data_loader": train_data_loader, | ||
"val_data_loader": val_data_loader, | ||
"train_params": train_params, | ||
} | ||
) | ||
|
||
|
||
SCHEMA_FORMAT_VERSION_TO_SCHEMA = {4: schema_v4} | ||
|
||
|
||
def load_config_data(path: str) -> dict: | ||
yaml_string = Path(path).read_text() | ||
cfg_without_schema = load(yaml_string, schema=None) | ||
schema_version = int(cfg_without_schema["format_version"]) | ||
if schema_version not in SCHEMA_FORMAT_VERSION_TO_SCHEMA: | ||
raise Exception(f"Unsupported schema format version: {schema_version}.") | ||
|
||
strict_cfg = load(yaml_string, schema=SCHEMA_FORMAT_VERSION_TO_SCHEMA[schema_version]) | ||
cfg: dict = strict_cfg.data | ||
return cfg | ||
|
||
|
||
def config_data_to_config(data): # type: ignore | ||
return as_document(data, schema_v4) | ||
|
||
|
||
def save_config_data(data: dict, path: str) -> None: | ||
cfg_document = config_data_to_config(data) | ||
with open(Path(path), "w") as f: | ||
f.write(cfg_document.as_yaml()) |
Oops, something went wrong.