diff --git a/interfaces/ros1/wavemap_ros/include/wavemap_ros/map_operations/crop_map_operation.h b/interfaces/ros1/wavemap_ros/include/wavemap_ros/map_operations/crop_map_operation.h index 7ecfce9c4..6a8b5867f 100644 --- a/interfaces/ros1/wavemap_ros/include/wavemap_ros/map_operations/crop_map_operation.h +++ b/interfaces/ros1/wavemap_ros/include/wavemap_ros/map_operations/crop_map_operation.h @@ -3,10 +3,10 @@ #include #include -#include #include #include +#include #include #include @@ -48,6 +48,7 @@ class CropMapOperation : public MapOperationBase { public: CropMapOperation(const CropMapOperationConfig& config, MapBase::Ptr occupancy_map, + std::shared_ptr thread_pool, std::shared_ptr transformer, std::string world_frame); @@ -57,6 +58,7 @@ class CropMapOperation : public MapOperationBase { private: const CropMapOperationConfig config_; + const std::shared_ptr thread_pool_; const std::shared_ptr transformer_; const std::string world_frame_; ros::Time last_run_timestamp_; diff --git a/interfaces/ros1/wavemap_ros/include/wavemap_ros/map_operations/decay_map_operation.h b/interfaces/ros1/wavemap_ros/include/wavemap_ros/map_operations/decay_map_operation.h index 43f8bbd61..0662dd383 100644 --- a/interfaces/ros1/wavemap_ros/include/wavemap_ros/map_operations/decay_map_operation.h +++ b/interfaces/ros1/wavemap_ros/include/wavemap_ros/map_operations/decay_map_operation.h @@ -1,11 +1,12 @@ #ifndef WAVEMAP_ROS_MAP_OPERATIONS_DECAY_MAP_OPERATION_H_ #define WAVEMAP_ROS_MAP_OPERATIONS_DECAY_MAP_OPERATION_H_ -#include +#include #include #include #include +#include #include #include @@ -29,9 +30,8 @@ struct DecayMapOperationConfig : public ConfigBase { class DecayMapOperation : public MapOperationBase { public: DecayMapOperation(const DecayMapOperationConfig& config, - MapBase::Ptr occupancy_map) - : MapOperationBase(std::move(occupancy_map)), - config_(config.checkValid()) {} + MapBase::Ptr occupancy_map, + std::shared_ptr thread_pool); bool shouldRun(const ros::Time& current_time = ros::Time::now()); @@ -39,6 +39,7 @@ class DecayMapOperation : public MapOperationBase { private: const DecayMapOperationConfig config_; + const std::shared_ptr thread_pool_; ros::Time last_run_timestamp_; Stopwatch timer_; }; diff --git a/interfaces/ros1/wavemap_ros/src/map_operations/crop_map_operation.cc b/interfaces/ros1/wavemap_ros/src/map_operations/crop_map_operation.cc index a4a2f490d..6c8ea052a 100644 --- a/interfaces/ros1/wavemap_ros/src/map_operations/crop_map_operation.cc +++ b/interfaces/ros1/wavemap_ros/src/map_operations/crop_map_operation.cc @@ -31,10 +31,12 @@ bool CropMapOperationConfig::isValid(bool verbose) const { CropMapOperation::CropMapOperation(const CropMapOperationConfig& config, MapBase::Ptr occupancy_map, + std::shared_ptr thread_pool, std::shared_ptr transformer, std::string world_frame) : MapOperationBase(std::move(occupancy_map)), config_(config.checkValid()), + thread_pool_(std::move(thread_pool)), transformer_(std::move(transformer)), world_frame_(std::move(world_frame)) {} @@ -85,13 +87,14 @@ void CropMapOperation::run(bool force_run) { dynamic_cast(occupancy_map_.get()); hashed_wavelet_octree) { crop_to_sphere(T_W_B->getPosition(), config_.radius, *hashed_wavelet_octree, - termination_height_); + termination_height_, thread_pool_); } else if (auto* hashed_chunked_wavelet_octree = dynamic_cast( occupancy_map_.get()); hashed_chunked_wavelet_octree) { crop_to_sphere(T_W_B->getPosition(), config_.radius, - *hashed_chunked_wavelet_octree, termination_height_); + *hashed_chunked_wavelet_octree, termination_height_, + thread_pool_); } else { ROS_WARN( "Map cropping is only supported for hash-based map data structures."); diff --git a/interfaces/ros1/wavemap_ros/src/map_operations/decay_map_operation.cc b/interfaces/ros1/wavemap_ros/src/map_operations/decay_map_operation.cc index 38a710e34..029ab16f5 100644 --- a/interfaces/ros1/wavemap_ros/src/map_operations/decay_map_operation.cc +++ b/interfaces/ros1/wavemap_ros/src/map_operations/decay_map_operation.cc @@ -1,5 +1,8 @@ #include "wavemap_ros/map_operations/decay_map_operation.h" +#include +#include + #include #include @@ -20,6 +23,13 @@ bool DecayMapOperationConfig::isValid(bool verbose) const { return all_valid; } +DecayMapOperation::DecayMapOperation(const DecayMapOperationConfig& config, + MapBase::Ptr occupancy_map, + std::shared_ptr thread_pool) + : MapOperationBase(std::move(occupancy_map)), + config_(config.checkValid()), + thread_pool_(std::move(thread_pool)) {} + bool DecayMapOperation::shouldRun(const ros::Time& current_time) { return config_.once_every < (current_time - last_run_timestamp_).toSec(); } @@ -41,12 +51,12 @@ void DecayMapOperation::run(bool force_run) { if (auto* hashed_wavelet_octree = dynamic_cast(occupancy_map_.get()); hashed_wavelet_octree) { - multiply(*hashed_wavelet_octree, config_.decay_rate); + multiply(*hashed_wavelet_octree, config_.decay_rate, thread_pool_); } else if (auto* hashed_chunked_wavelet_octree = dynamic_cast( occupancy_map_.get()); hashed_chunked_wavelet_octree) { - multiply(*hashed_chunked_wavelet_octree, config_.decay_rate); + multiply(*hashed_chunked_wavelet_octree, config_.decay_rate, thread_pool_); } else { ROS_WARN("Map decay is only supported for hash-based map data structures."); } diff --git a/interfaces/ros1/wavemap_ros/src/map_operations/map_ros_operation_factory.cc b/interfaces/ros1/wavemap_ros/src/map_operations/map_ros_operation_factory.cc index 3a02f0f3d..4e68d5cf2 100644 --- a/interfaces/ros1/wavemap_ros/src/map_operations/map_ros_operation_factory.cc +++ b/interfaces/ros1/wavemap_ros/src/map_operations/map_ros_operation_factory.cc @@ -59,16 +59,16 @@ std::unique_ptr MapRosOperationFactory::create( case MapRosOperationType::kCropMap: if (const auto config = CropMapOperationConfig::from(params); config) { return std::make_unique( - config.value(), std::move(occupancy_map), std::move(transformer), - std::move(world_frame)); + config.value(), std::move(occupancy_map), std::move(thread_pool), + std::move(transformer), std::move(world_frame)); } else { ROS_ERROR("Crop map operation config could not be loaded."); return nullptr; } case MapRosOperationType::kDecayMap: if (const auto config = DecayMapOperationConfig::from(params); config) { - return std::make_unique(config.value(), - std::move(occupancy_map)); + return std::make_unique( + config.value(), std::move(occupancy_map), std::move(thread_pool)); } else { ROS_ERROR("Decay map operation config could not be loaded."); return nullptr; diff --git a/library/cpp/include/wavemap/core/utils/edit/crop.h b/library/cpp/include/wavemap/core/utils/edit/crop.h index 3b21b63e3..a7a2e2053 100644 --- a/library/cpp/include/wavemap/core/utils/edit/crop.h +++ b/library/cpp/include/wavemap/core/utils/edit/crop.h @@ -1,7 +1,10 @@ #ifndef WAVEMAP_CORE_UTILS_EDIT_CROP_H_ #define WAVEMAP_CORE_UTILS_EDIT_CROP_H_ +#include + #include "wavemap/core/common.h" +#include "wavemap/core/utils/thread_pool.h" namespace wavemap { template @@ -88,10 +91,13 @@ void cropNodeRecursive(typename MapType::Block::OctreeType::NodeRefType node, template void crop_to_sphere(const Point3D& t_W_center, FloatingPoint radius, - MapType& map, IndexElement termination_height) { + MapType& map, IndexElement termination_height, + const std::shared_ptr& thread_pool = nullptr) { + using NodePtrType = typename MapType::Block::OctreeType::NodePtrType; const IndexElement tree_height = map.getTreeHeight(); const FloatingPoint min_cell_width = map.getMinCellWidth(); + // Check all blocks for (auto it = map.getHashMap().begin(); it != map.getHashMap().end();) { // Start by testing at the block level const Index3D& block_index = it->first; @@ -112,13 +118,33 @@ void crop_to_sphere(const Point3D& t_W_center, FloatingPoint radius, // Since the block overlaps with the sphere's boundary, we need to process // it at a higher resolution by recursing over its cells auto& block = it->second; - cropNodeRecursive(block.getRootNode(), block_node_index, - block.getRootScale(), t_W_center, radius, - min_cell_width, termination_height); + // Indicate that the block has changed block.setLastUpdatedStamp(); - + // Get pointers to the root value and node, which contain the wavelet + // scale and detail coefficients, respectively + FloatingPoint* root_value_ptr = &block.getRootScale(); + NodePtrType root_node_ptr = &block.getRootNode(); + // Recursively crop all nodes + if (thread_pool) { + thread_pool->add_task([root_node_ptr, root_value_ptr, block_node_index, + t_W_center, radius, min_cell_width, + termination_height]() { + cropNodeRecursive(*root_node_ptr, block_node_index, + *root_value_ptr, t_W_center, radius, + min_cell_width, termination_height); + }); + } else { + cropNodeRecursive(*root_node_ptr, block_node_index, + *root_value_ptr, t_W_center, radius, + min_cell_width, termination_height); + } + // Advance to the next block ++it; } + // Wait for all parallel jobs to finish + if (thread_pool) { + thread_pool->wait_all(); + } } } // namespace wavemap diff --git a/library/cpp/include/wavemap/core/utils/edit/multiply.h b/library/cpp/include/wavemap/core/utils/edit/multiply.h index 794addc9a..a2e237494 100644 --- a/library/cpp/include/wavemap/core/utils/edit/multiply.h +++ b/library/cpp/include/wavemap/core/utils/edit/multiply.h @@ -1,7 +1,10 @@ #ifndef WAVEMAP_CORE_UTILS_EDIT_MULTIPLY_H_ #define WAVEMAP_CORE_UTILS_EDIT_MULTIPLY_H_ +#include + #include "wavemap/core/common.h" +#include "wavemap/core/utils/thread_pool.h" namespace wavemap { @@ -21,12 +24,31 @@ void multiplyNodeRecursive( } template -void multiply(MapType& map, FloatingPoint multiplier) { - map.forEachBlock([multiplier](const Index3D& /*block_index*/, auto& block) { - block.getRootScale() *= multiplier; - multiplyNodeRecursive(block.getRootNode(), multiplier); +void multiply(MapType& map, FloatingPoint multiplier, + const std::shared_ptr& thread_pool = nullptr) { + using NodePtrType = typename MapType::Block::OctreeType::NodePtrType; + + // Process all blocks + for (auto& [block_index, block] : map.getHashMap()) { + // Indicate that the block has changed block.setLastUpdatedStamp(); - }); + // Multiply the block's average value (wavelet scale coefficient) + FloatingPoint& root_value = block.getRootScale(); + root_value *= multiplier; + // Recursively multiply all node values (wavelet detail coefficients) + NodePtrType root_node_ptr = &block.getRootNode(); + if (thread_pool) { + thread_pool->add_task([root_node_ptr, multiplier]() { + multiplyNodeRecursive(*root_node_ptr, multiplier); + }); + } else { + multiplyNodeRecursive(*root_node_ptr, multiplier); + } + } + // Wait for all parallel jobs to finish + if (thread_pool) { + thread_pool->wait_all(); + } } } // namespace wavemap