Skip to content

Commit

Permalink
new terrain evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
luisa-mao committed May 28, 2024
1 parent 7366b6f commit 64c5217
Show file tree
Hide file tree
Showing 10 changed files with 256 additions and 1,335 deletions.
2 changes: 2 additions & 0 deletions src/navigation/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.png
*.bag
114 changes: 114 additions & 0 deletions src/navigation/load_model_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#include <torch/script.h> // One-stop header.
#include <opencv2/opencv.hpp>

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
// load an image
cv::Mat3b img = cv::imread("13_140.png", cv::IMREAD_COLOR);
std::cout << "Image rows: " << img.rows << std::endl;
std::cout << "Image cols: " << img.cols << std::endl;
// resize width to 256 and height to 128
cv::resize(img, img, cv::Size(256, 128));
// print the rows and cols
std::cout << "Image rows: " << img.rows << std::endl;
std::cout << "Image cols: " << img.cols << std::endl;
// bgr to rgb
cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
// transpose to 3x128x256
// img = img.t();
// reshape to 1x3x128x256
auto img_tensor = torch::from_blob(img.data, {1, img.rows, img.cols, 3}, torch::kByte);
img_tensor = img_tensor.permute({0, 3, 1, 2});
// convert to float between 0 and 1
img_tensor = img_tensor.to(torch::kFloat32);
img_tensor = img_tensor / 255.0;

// print the shape
std::cout << "Image shape: ";
for (auto& size : img_tensor.sizes()) {
std::cout << size << " ";
}

// print the min and max
std::cout << std::endl;
std::cout << "Image min: " << img_tensor.min().item<float>() << std::endl;
std::cout << "Image max: " << img_tensor.max().item<float>() << std::endl;



if (argc != 2) {
std::cerr << "usage: example-app <path-to-exported-script-module>\n";
return -1;
}


torch::jit::script::Module module;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(argv[1]);

std::cout << "Model loaded successfully!" << std::endl;
torch::NoGradGuard no_grad;

// Create example_bev tensor
// auto example_bev = torch::randn({1, 3, 128, 256});

// Create example_context tensor
auto example_context = torch::randn({1, 9, 128, 64});

std::cout << "Example tensors created!" << std::endl;

// Create a vector of inputs
std::vector<torch::jit::IValue> inputs;
inputs.push_back(example_context);
inputs.push_back(img_tensor);

// auto i = torch::jit::IValue(std::make_tuple(example_context, example_bev));
// inputs.push_back(i);

std::cout << "forward inference started!" << std::endl;

// Get all the methods of the module
auto methods = module.get_methods();

// Print all the methods
for (const auto& method : methods) {
std::cout << "Method: " << method.name() << std::endl;
}
// Run the model with the inputs
auto output = module.forward(inputs).toTensor();

// apply sigmoid to the output
output = torch::sigmoid(output);

// Print the shape of the output
std::cout << "Output shape: ";
for (auto& size : output.sizes()) {
std::cout << size << " ";
}
std::cout << std::endl;

// Print the min and max of the output
std::cout << "Output min: " << output.min().item<float>() << std::endl;
std::cout << "Output max: " << output.max().item<float>() << std::endl;

// write the output to a grayscale image
cv::Mat1f output_mat(output.size(2), output.size(3), output.data_ptr<float>());
// save to disk
cv::imwrite("output.png", output_mat * 255);




}
catch (const c10::Error& e) {
std::cerr << "an error occurred\n";
// print the error
std::cerr << e.what();
return -1;
}

std::cout << "ok\n";
}
7 changes: 6 additions & 1 deletion src/navigation/navigation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "deep_cost_map_evaluator.h"
#include "linear_evaluator.h"
#include "terrain_evaluator.h"
#include "terrain_evaluator2.h"

using Eigen::Rotation2Df;
using Eigen::Vector2f;
Expand Down Expand Up @@ -169,6 +170,10 @@ void Navigation::Initialize(const NavigationParameters& params,
auto terrain_evaluator = std::make_shared<TerrainEvaluator>();
terrain_evaluator->LoadModel();
evaluator_ = terrain_evaluator;
} else if (params_.evaluator_type == "terrain2") {
auto terrain_evaluator = std::make_shared<CustomTerrainEvaluator>();
terrain_evaluator->LoadModel();
evaluator_ = terrain_evaluator;
} else if (params_.evaluator_type == "linear") {
evaluator_ = std::make_shared<LinearEvaluator>();
} else {
Expand Down Expand Up @@ -809,7 +814,7 @@ vector<std::shared_ptr<PathRolloutBase>> Navigation::GetLastPathOptions() {
const cv::Mat& Navigation::GetVisualizationImage() {
if (params_.evaluator_type == "cost_map") {
return std::dynamic_pointer_cast<DeepCostMapEvaluator>(evaluator_)->latest_vis_image_;
} else if (params_.evaluator_type == "terrain") {
} else if (params_.evaluator_type == "terrain" || params_.evaluator_type == "terrain2") {
return std::dynamic_pointer_cast<TerrainEvaluator>(evaluator_)->latest_vis_image_;
} else {
std::cerr << "No visualization image for linear evaluator" << std::endl;
Expand Down
8 changes: 4 additions & 4 deletions src/navigation/navigation_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -860,11 +860,11 @@ int main(int argc, char** argv) {
std_msgs::Header viz_img_header; // empty viz_img_header
viz_img_header.stamp = ros::Time::now(); // time
cv_bridge::CvImage viz_img;
if (params.evaluator_type == "cost_map" || params.evaluator_type == "terrain") {
if (params.evaluator_type == "cost_map" || params.evaluator_type == "terrain" || params.evaluator_type == "terrain2") {
viz_img = cv_bridge::CvImage(viz_img_header, sensor_msgs::image_encodings::RGB8, navigation_.GetVisualizationImage());
}
cv_bridge::CvImage cost_img;
if (params.evaluator_type == "terrain") {
if (params.evaluator_type == "terrain" || params.evaluator_type == "terrain2") {
cost_img = cv_bridge::CvImage(std_msgs::Header(), sensor_msgs::image_encodings::RGB8, navigation_.GetCostMapImage());
}

Expand Down Expand Up @@ -895,11 +895,11 @@ int main(int argc, char** argv) {
global_viz_msg_.header.stamp = ros::Time::now();
viz_pub_.publish(local_viz_msg_);
viz_pub_.publish(global_viz_msg_);
if (params.evaluator_type == "cost_map" || params.evaluator_type == "terrain") {
if (params.evaluator_type == "cost_map" || params.evaluator_type == "terrain" || params.evaluator_type == "terrain2") {
viz_img.image = navigation_.GetVisualizationImage();
viz_img_pub_.publish(viz_img.toImageMsg());
}
if (params.evaluator_type == "terrain") {
if (params.evaluator_type == "terrain" || params.evaluator_type == "terrain2") {
cost_img.image = navigation_.GetCostMapImage();
cost_img_pub_.publish(cost_img.toImageMsg());
}
Expand Down
Loading

0 comments on commit 64c5217

Please sign in to comment.