Skip to content

Commit

Permalink
Relative imports
Browse files Browse the repository at this point in the history
  • Loading branch information
rpautrat committed Feb 11, 2022
1 parent 492a71d commit 3b340f0
Show file tree
Hide file tree
Showing 54 changed files with 95 additions and 90 deletions.
25 changes: 14 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ pip install -r requirements.txt

Set your dataset and experiment paths (where you will store your datasets and checkpoints of your experiments) by modifying the file `config/project_config.py`. Both variables `DATASET_ROOT` and `EXP_PATH` have to be set.

Install the Python package:
```bash
pip install -e .
```

You can download the version of the [Wireframe dataset](https://github.com/huangkuns/wireframe) that we used during our training and testing [here](https://www.polybox.ethz.ch/index.php/s/IfdEf7RoHol7jeg). This repository also includes some files to train on the [Holicity dataset](https://holicity.io/) to add more outdoor images, but note that we did not extensively test this dataset and the original paper was based on the Wireframe dataset only.

### Training your own model
Expand All @@ -34,7 +39,7 @@ All training parameters are located in configuration files in the folder `config

The following command will create the synthetic dataset and start training the model on it:
```bash
python experiment.py --mode train --dataset_config config/synthetic_dataset.yaml --model_config config/train_detector.yaml --exp_name sold2_synth
python -m sold2.experiment --mode train --dataset_config config/synthetic_dataset.yaml --model_config config/train_detector.yaml --exp_name sold2_synth
```
</details>

Expand All @@ -43,22 +48,20 @@ python experiment.py --mode train --dataset_config config/synthetic_dataset.yaml

Note that this step can take one to several days depending on your machine and on the size of the dataset. You can set the batch size to the maximum capacity that your GPU can handle. Prior to this step, make sure that the dataset config file `config/wireframe_dataset.yaml` has the lines `gt_source_train` and `gt_source_test` commented and you should also disable the photometric and homographic augmentations.
```bash
python experiment.py --exp_name wireframe_train --mode export --resume_path <path to your previously trained sold2_synth> --model_config config/train_detector.yaml --dataset_config config/wireframe_dataset.yaml --checkpoint_name <name of the best checkpoint> --export_dataset_mode train --export_batch_size 4
python -m sold2.experiment --exp_name wireframe_train --mode export --resume_path <path to your previously trained sold2_synth> --model_config config/train_detector.yaml --dataset_config config/wireframe_dataset.yaml --checkpoint_name <name of the best checkpoint> --export_dataset_mode train --export_batch_size 4
```

You can similarly perform the same for the test set:
```bash
python experiment.py --exp_name wireframe_test --mode export --resume_path <path to your previously trained sold2_synth> --model_config config/train_detector.yaml --dataset_config config/wireframe_dataset.yaml --checkpoint_name <name of the best checkpoint> --export_dataset_mode test --export_batch_size 4
python -m sold2.experiment --exp_name wireframe_test --mode export --resume_path <path to your previously trained sold2_synth> --model_config config/train_detector.yaml --dataset_config config/wireframe_dataset.yaml --checkpoint_name <name of the best checkpoint> --export_dataset_mode test --export_batch_size 4
```
</details>

<details>
<summary><b>Step3: Compute the ground truth line segments from the raw data</b></summary>

```bash
cd postprocess
python convert_homography_results.py <name of the previously exported file (e.g. "wireframe_train.h5")> <name of the new data with extracted line segments (e.g. "wireframe_train_gt.h5")> ../config/export_line_features.yaml
cd ..
python -m sold2.postprocess.convert_homography_results <name of the previously exported file (e.g. "wireframe_train.h5")> <name of the new data with extracted line segments (e.g. "wireframe_train_gt.h5")> config/export_line_features.yaml
```

We recommend testing the results on a few samples of your dataset to check the quality of the output, and modifying the hyperparameters if need be. Using a `detect_thresh=0.5` and `inlier_thresh=0.99` proved to be successful for the Wireframe dataset in our case for example.
Expand All @@ -70,17 +73,17 @@ We recommend testing the results on a few samples of your dataset to check the q
We found it easier to pretrain the detector alone first, before fine-tuning it with the descriptor part.
Uncomment the lines 'gt_source_train' and 'gt_source_test' in `config/wireframe_dataset.yaml` and fill them with the path to the h5 file generated in the previous step.
```bash
python experiment.py --mode train --dataset_config config/wireframe_dataset.yaml --model_config config/train_detector.yaml --exp_name sold2_wireframe
python -m sold2.experiment --mode train --dataset_config config/wireframe_dataset.yaml --model_config config/train_detector.yaml --exp_name sold2_wireframe
```

Alternatively, you can also fine-tune the already trained synthetic model:
```bash
python experiment.py --mode train --dataset_config config/wireframe_dataset.yaml --model_config config/train_detector.yaml --exp_name sold2_wireframe --pretrained --pretrained_path <path ot the pre-trained sold2_synth> --checkpoint_name <name of the best checkpoint>
python -m sold2.experiment --mode train --dataset_config config/wireframe_dataset.yaml --model_config config/train_detector.yaml --exp_name sold2_wireframe --pretrained --pretrained_path <path ot the pre-trained sold2_synth> --checkpoint_name <name of the best checkpoint>
```

Lastly, you can resume a training that was stopped:
```bash
python experiment.py --mode train --dataset_config config/wireframe_dataset.yaml --model_config config/train_detector.yaml --exp_name sold2_wireframe --resume --resume_path <path to the model to resume> --checkpoint_name <name of the last checkpoint>
python -m sold2.experiment --mode train --dataset_config config/wireframe_dataset.yaml --model_config config/train_detector.yaml --exp_name sold2_wireframe --resume --resume_path <path to the model to resume> --checkpoint_name <name of the last checkpoint>
```
</details>

Expand All @@ -89,7 +92,7 @@ python experiment.py --mode train --dataset_config config/wireframe_dataset.yaml

You first need to modify the field 'return_type' in `config/wireframe_dataset.yaml` to 'paired_desc'. The following command will then train the full model (detector + descriptor) on the Wireframe dataset:
```bash
python experiment.py --mode train --dataset_config config/wireframe_dataset.yaml --model_config config/train_full_pipeline.yaml --exp_name sold2_full_wireframe --pretrained --pretrained_path <path ot the pre-trained sold2_wireframe> --checkpoint_name <name of the best checkpoint>
python -m sold2.experiment --mode train --dataset_config config/wireframe_dataset.yaml --model_config config/train_full_pipeline.yaml --exp_name sold2_full_wireframe --pretrained --pretrained_path <path ot the pre-trained sold2_wireframe> --checkpoint_name <name of the best checkpoint>
```
</details>

Expand All @@ -105,7 +108,7 @@ We provide the checkpoints of two pretrained models:

We provide a [notebook](notebooks/match_lines.ipynb) showing how to use the trained model of SOLD². Additionally, you can use the model to export line features (segments and descriptor maps) as follows:
```bash
python export_line_features.py --img_list <list to a txt file containing the path to all the images> --output_folder <path to the output folder> --checkpoint_path <path to your best checkpoint,>
python -m sold2.export_line_features --img_list <list to a txt file containing the path to all the images> --output_folder <path to the output folder> --checkpoint_path <path to your best checkpoint,>
```

You can tune some of the line detection parameters in `config/export_line_features.yaml`, in particular the 'detect_thresh' and 'inlier_thresh' to adapt them to your trained model and type of images. As the line detection can be sensitive to the image resolution, we recommend using it with images in the range 300~800 px per side.
Expand Down
10 changes: 4 additions & 6 deletions notebooks/match_lines.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append(\"../\")\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import cv2\n",
"import torch\n",
"\n",
"from model.line_matcher import LineMatcher\n",
"from misc.visualize_util import plot_images, plot_lines, plot_line_matches, plot_color_line_matches, plot_keypoints"
"from sold2.model.line_matcher import LineMatcher\n",
"from sold2.misc.visualize_util import plot_images, plot_lines, plot_line_matches, plot_color_line_matches, plot_keypoints"
]
},
{
Expand Down Expand Up @@ -212,7 +210,7 @@
"metadata": {
"file_extension": ".py",
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -226,7 +224,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
"version": "3.8.10"
},
"mimetype": "text/x-python",
"name": "python",
Expand Down
16 changes: 7 additions & 9 deletions notebooks/visualize_exported_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,14 @@
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('../')\n",
"import numpy as np\n",
"import yaml\n",
"\n",
"from dataset.wireframe_dataset import WireframeDataset\n",
"from dataset.holicity_dataset import HolicityDataset\n",
"from dataset.merge_dataset import MergeDataset\n",
"from misc.visualize_util import plot_junctions, plot_line_segments\n",
"from misc.visualize_util import plot_images, plot_keypoints"
"from sold2.dataset.wireframe_dataset import WireframeDataset\n",
"from sold2.dataset.holicity_dataset import HolicityDataset\n",
"from sold2.dataset.merge_dataset import MergeDataset\n",
"from sold2.misc.visualize_util import plot_junctions, plot_line_segments\n",
"from sold2.misc.visualize_util import plot_images, plot_keypoints"
]
},
{
Expand Down Expand Up @@ -378,7 +376,7 @@
"metadata": {
"file_extension": ".py",
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -392,7 +390,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
"version": "3.8.10"
},
"mimetype": "text/x-python",
"name": "python",
Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from setuptools import setup


setup(name='sold2', version="0.0", packages=['sold2'])
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
8 changes: 4 additions & 4 deletions dataset/dataset_util.py → sold2/dataset/dataset_util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
The interface of initializing different datasets.
"""
from dataset.synthetic_dataset import SyntheticShapes
from dataset.wireframe_dataset import WireframeDataset
from dataset.holicity_dataset import HolicityDataset
from dataset.merge_dataset import MergeDataset
from .synthetic_dataset import SyntheticShapes
from .wireframe_dataset import WireframeDataset
from .holicity_dataset import HolicityDataset
from .merge_dataset import MergeDataset


def get_dataset(mode="train", dataset_cfg=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
from torch.utils.data import Dataset
from torchvision import transforms

from config.project_config import Config as cfg
import dataset.transforms.photometric_transforms as photoaug
import dataset.transforms.homographic_transforms as homoaug
from dataset.transforms.utils import random_scaling
from dataset.synthetic_util import get_line_heatmap
from misc.geometry_utils import warp_points, mask_points
from misc.train_utils import parse_h5_data
from ..config.project_config import Config as cfg
from .transforms import photometric_transforms as photoaug
from .transforms import homographic_transforms as homoaug
from .transforms.utils import random_scaling
from .synthetic_util import get_line_heatmap
from ..misc.geometry_utils import warp_points, mask_points
from ..misc.train_utils import parse_h5_data


def holicity_collate_fn(batch):
Expand Down
4 changes: 2 additions & 2 deletions dataset/merge_dataset.py → sold2/dataset/merge_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from copy import deepcopy
from torch.utils.data import Dataset

from dataset.wireframe_dataset import WireframeDataset
from dataset.holicity_dataset import HolicityDataset
from .wireframe_dataset import WireframeDataset
from .holicity_dataset import HolicityDataset


class MergeDataset(Dataset):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from torch.utils.data import Dataset
import torch.utils.data.dataloader as torch_loader

from config.project_config import Config as cfg
from dataset import synthetic_util
import dataset.transforms.photometric_transforms as photoaug
import dataset.transforms.homographic_transforms as homoaug
from misc.train_utils import parse_h5_data
from ..config.project_config import Config as cfg
from . import synthetic_util
from .transforms import photometric_transforms as photoaug
from .transforms import homographic_transforms as homoaug
from ..misc.train_utils import parse_h5_data


def synthetic_collate_fn(batch):
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""
import numpy as np
from math import pi
from dataset.synthetic_util import get_line_map, get_line_heatmap

from ..synthetic_util import get_line_map, get_line_heatmap
import cv2
import copy
import shapely.geometry
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import cv2
import numpy as np
import shapely.geometry as sg
from dataset.synthetic_util import get_line_map
import dataset.transforms.homographic_transforms as homoaug

from ..synthetic_util import get_line_map
from . import homographic_transforms as homoaug


def random_scaling(image, junctions, line_map, scale=1., h_crop=0, w_crop=0):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from torch.utils.data import Dataset
from torchvision import transforms

from config.project_config import Config as cfg
import dataset.transforms.photometric_transforms as photoaug
import dataset.transforms.homographic_transforms as homoaug
from dataset.transforms.utils import random_scaling
from dataset.synthetic_util import get_line_heatmap
from misc.train_utils import parse_h5_data
from misc.geometry_utils import warp_points, mask_points
from ..config.project_config import Config as cfg
from .transforms import photometric_transforms as photoaug
from .transforms import homographic_transforms as homoaug
from .transforms.utils import random_scaling
from .synthetic_util import get_line_heatmap
from ..misc.train_utils import parse_h5_data
from ..misc.geometry_utils import warp_points, mask_points


def wireframe_collate_fn(batch):
Expand Down
6 changes: 3 additions & 3 deletions experiment.py → sold2/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import numpy as np
import torch

from config.project_config import Config as cfg
from train import train_net
from export import export_predictions, export_homograpy_adaptation
from .config.project_config import Config as cfg
from .train import train_net
from .export import export_predictions, export_homograpy_adaptation


# Pytorch configurations
Expand Down
10 changes: 5 additions & 5 deletions export.py → sold2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from torch.utils.data import DataLoader
from kornia.geometry import warp_perspective

from dataset.dataset_util import get_dataset
from model.model_util import get_model
from misc.train_utils import get_latest_checkpoint
from train import convert_junc_predictions
from dataset.transforms.homographic_transforms import sample_homography
from .dataset.dataset_util import get_dataset
from .model.model_util import get_model
from .misc.train_utils import get_latest_checkpoint
from .train import convert_junc_predictions
from .dataset.transforms.homographic_transforms import sample_homography


def restore_weights(model, state_dict):
Expand Down
4 changes: 2 additions & 2 deletions export_line_features.py → sold2/export_line_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import torch
from tqdm import tqdm

from experiment import load_config
from model.line_matcher import LineMatcher
from .experiment import load_config
from .model.line_matcher import LineMatcher


def export_descriptors(images_list, ckpt_path, config, device, extension,
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
8 changes: 4 additions & 4 deletions model/line_detector.py → sold2/model/line_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import torch
from torch.nn.functional import softmax

from model.model_util import get_model
from model.loss import get_loss_and_weights
from model.line_detection import LineSegmentDetectionModule
from train import convert_junc_predictions
from .model_util import get_model
from .loss import get_loss_and_weights
from .line_detection import LineSegmentDetectionModule
from ..train import convert_junc_predictions


def line_map_to_segments(junctions, line_map):
Expand Down
14 changes: 7 additions & 7 deletions model/line_matcher.py → sold2/model/line_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import torch.nn.functional as F
from torch.nn.functional import softmax

from model.model_util import get_model
from model.loss import get_loss_and_weights
from model.metrics import super_nms
from model.line_detection import LineSegmentDetectionModule
from model.line_matching import WunschLineMatcher
from train import convert_junc_predictions
from model.line_detector import line_map_to_segments
from .model_util import get_model
from .loss import get_loss_and_weights
from .metrics import super_nms
from .line_detection import LineSegmentDetectionModule
from .line_matching import WunschLineMatcher
from ..train import convert_junc_predictions
from .line_detector import line_map_to_segments


class LineMatcher(object):
Expand Down
2 changes: 1 addition & 1 deletion model/line_matching.py → sold2/model/line_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import torch.nn.functional as F

from misc.geometry_utils import keypoints_to_grid
from ..misc.geometry_utils import keypoints_to_grid


class WunschLineMatcher(object):
Expand Down
4 changes: 2 additions & 2 deletions model/loss.py → sold2/model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch.nn.functional as F
from kornia.geometry import warp_perspective

from misc.geometry_utils import (keypoints_to_grid, get_dist_mask,
get_common_line_mask)
from ..misc.geometry_utils import (keypoints_to_grid, get_dist_mask,
get_common_line_mask)


def get_loss_and_weights(model_cfg, device=torch.device("cuda")):
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion model/metrics.py → sold2/model/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from torchvision.ops.boxes import batched_nms

from misc.geometry_utils import keypoints_to_grid
from ..misc.geometry_utils import keypoints_to_grid


class Metrics(object):
Expand Down
8 changes: 4 additions & 4 deletions model/model_util.py → sold2/model/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import torch.nn as nn
import torch.nn.init as init

from model.nets.backbone import HourglassBackbone, SuperpointBackbone
from model.nets.junction_decoder import SuperpointDecoder
from model.nets.heatmap_decoder import PixelShuffleDecoder
from model.nets.descriptor_decoder import SuperpointDescriptor
from .nets.backbone import HourglassBackbone, SuperpointBackbone
from .nets.junction_decoder import SuperpointDecoder
from .nets.heatmap_decoder import PixelShuffleDecoder
from .nets.descriptor_decoder import SuperpointDescriptor


def get_model(model_cfg=None, loss_weights=None, mode="train"):
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 3b340f0

Please sign in to comment.